
NLP in Practice (4): Building an LSTM Model with PyTorch to Predict Diabetes

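Contents

1. Data Preparation

2. Creating the Data Loaders

3. Building the LSTM Model

4. Model Training

5. Model Evaluation

6. Visualizing the Training Process

7. Summary

8. Experiment Log and Download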


In this post, I will walk through how to build a two-layer LSTM model with PyTorch to predict diabetes.

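We start with loading the data and then cover model construction, the training process, and evaluation of the results step by step.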

1. Data Preparation

First, we load and prepare the data:

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# Load the data (the CSV has no header row; the last column is the label)
data = pd.read_csv('diabetes.csv', header=None)
X = data.iloc[:, :-1].values  # features
y = data.iloc[:, -1].values   # labels

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert to PyTorch tensors
X_train = torch.FloatTensor(X_train)  # shape (num_samples, 8)
X_test = torch.FloatTensor(X_test)    # shape (num_samples, 8)
y_train = torch.FloatTensor(y_train)
y_test = torch.FloatTensor(y_test)

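This code does the following:

  1. Imports the required libraries

  2. Loads the diabetes dataset from a CSV file

  3. Splits the data into features (X) and labels (y)

  4. Uses train_test_split to divide the data into a training set and a test set (80% training, 20% test)

  5. Converts the NumPy arrays to PyTorch tensors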
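As a quick sanity check of the split, here is a minimal sketch that runs the same steps on synthetic data instead of `diabetes.csv`. It assumes the file has 759 rows (8 feature columns plus 1 label column); that row count is an assumption, chosen because it is consistent with the accuracy values in the log in section 8 (for example, 118/152 ≈ 0.7763). It uses `torch.tensor(..., dtype=torch.float32)`, which yields the same float32 tensors as `torch.FloatTensor`.

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split

# Synthetic stand-in for diabetes.csv: 759 rows, 8 features + 1 binary label (assumed shape)
rng = np.random.default_rng(0)
features = rng.uniform(-1.0, 1.0, size=(759, 8))
labels = rng.integers(0, 2, size=759).astype(np.float64)

X_train, X_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.2, random_state=42
)

# Same conversion as in the post, written with torch.tensor and an explicit dtype
X_train_t = torch.tensor(X_train, dtype=torch.float32)
X_test_t = torch.tensor(X_test, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.float32)
y_test_t = torch.tensor(y_test, dtype=torch.float32)

print(X_train_t.shape, X_test_t.shape)   # torch.Size([607, 8]) torch.Size([152, 8])
print(X_train_t.dtype, y_train_t.dtype)  # torch.float32 torch.float32
```

With `test_size=0.2`, scikit-learn rounds the number of test samples up, so 759 rows give 152 test and 607 training samples.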

2. Creating the Data Loaders

To load the data efficiently in mini-batches, we use PyTorch's DataLoader:

# Create the DataLoaders
train_data = TensorDataset(X_train, y_train)
test_data = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32)

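Here we:

  • wrap the features and labels together with TensorDataset

  • create DataLoaders for training and testing with a batch size of 32

  • shuffle the training data (shuffle=True), while the test data keeps its original order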
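To make the batching concrete, the sketch below builds the two loaders on random tensors with an assumed 607 training and 152 test samples (an 80/20 split of 759 rows) and inspects how many batches one epoch contains and how large the last batch is.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumed sizes: 607 training and 152 test samples with 8 features each
X_train = torch.randn(607, 8)
y_train = torch.randint(0, 2, (607,)).float()
X_test = torch.randn(152, 8)
y_test = torch.randint(0, 2, (152,)).float()

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=32)

# len(loader) is the number of batches per epoch; the last batch holds the remainder
print(len(train_loader), len(test_loader))  # 19 5
train_batch_sizes = [xb.shape[0] for xb, _ in train_loader]
test_batch_sizes = [xb.shape[0] for xb, _ in test_loader]
print(train_batch_sizes[-1], test_batch_sizes[-1])  # 31 24
```

`len(train_loader)` is what the training loop divides `running_loss` by, so the reported epoch loss is a mean over batches in which the smaller last batch counts as much as a full one; with these sizes the difference is negligible.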

3. Building the LSTM Model

We build a two-layer LSTM model:

class LSTMModel(nn.Module):
    def __init__(self, input_size=8, hidden_size1=64, hidden_size2=32):
        super(LSTMModel, self).__init__()
        self.lstm1 = nn.LSTM(input_size, hidden_size1, batch_first=True)
        self.dropout1 = nn.Dropout(0.3)
        self.lstm2 = nn.LSTM(hidden_size1, hidden_size2, batch_first=True)
        self.dropout2 = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_size2, 1)

    def forward(self, x):
        # Add a sequence-length dimension: (batch_size, 8) -> (batch_size, 1, 8)
        x = x.unsqueeze(1)
        # First LSTM layer
        x, _ = self.lstm1(x)
        x = self.dropout1(x)
        # Second LSTM layer
        x, (hn, cn) = self.lstm2(x)
        x = self.dropout2(hn[-1])  # hidden state at the last time step
        x = self.fc(x)
        # squeeze(-1) keeps the batch dimension even if a batch contains a single sample
        return torch.sigmoid(x.squeeze(-1))

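Model characteristics:

  • The input has 8 features (matching the 8 features of the diabetes dataset)

  • The first LSTM layer has 64 hidden units

  • The second LSTM layer has 32 hidden units

  • Each LSTM layer is followed by a dropout layer (p = 0.3) to reduce overfitting

  • A final fully connected layer outputs a single value, which a sigmoid turns into a probability

  • In forward, we add a sequence-length dimension of size 1, because with batch_first=True nn.LSTM expects input of shape (batch, seq_len, features)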
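The shape bookkeeping in forward is easy to get wrong, so here is a minimal trace that pushes a random batch through standalone layers with the same sizes as LSTMModel. Dropout is left out because it does not change shapes, and the batch size of 4 is arbitrary. The last line checks a property of a single-layer, unidirectional nn.LSTM: the final hidden state `hn[-1]` equals the output at the last time step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size = 4  # arbitrary batch size for the trace

lstm1 = nn.LSTM(8, 64, batch_first=True)
lstm2 = nn.LSTM(64, 32, batch_first=True)
fc = nn.Linear(32, 1)

x = torch.randn(batch_size, 8)   # (batch, features)
x = x.unsqueeze(1)               # (batch, seq_len=1, features)
out1, _ = lstm1(x)               # (batch, 1, 64)
out2, (hn, cn) = lstm2(out1)     # out2: (batch, 1, 32); hn: (num_layers=1, batch, 32)
last = hn[-1]                    # (batch, 32)
logits = fc(last).squeeze(-1)    # (batch,)

print(x.shape, out1.shape, out2.shape, hn.shape, last.shape, logits.shape)
# For a single-layer, unidirectional LSTM, hn[-1] is the output at the last time step
print(torch.allclose(out2[:, -1, :], hn[-1]))  # True
```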

4. Model Training

We train the model with the Adam optimizer and BCELoss (binary cross-entropy loss):

# Initialize the model
model = LSTMModel(input_size=8)  # 8 features
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0002)

# Records for training and validation
train_losses = []
train_accs = []
val_losses = []
val_accs = []

# Train the model
epochs = 300
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = (outputs > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Compute and record the training metrics
    train_loss = running_loss / len(train_loader)
    train_acc = correct / total
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Compute and record the validation metrics
    val_loss = val_loss / len(test_loader)
    val_acc = correct / total
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

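The training process consists of:

  1. Initializing the model, the loss function, and the optimizer

  2. Training for 300 epochs

  3. In each epoch:

    • Training phase: forward pass, loss computation, backpropagation, and parameter update

    • Validation phase: evaluating the model on the test set

  4. Recording and printing the training and validation loss and accuracy. The training metrics are computed with dropout active and averaged over an epoch in which the weights are still changing, which is one reason the training accuracy in the log stays below the validation accuracy for most of the run.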
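BCELoss expects probabilities (hence the sigmoid at the end of the model) and computes the mean of -[y·log(p) + (1 - y)·log(1 - p)]. The sketch below checks that formula on four hand-picked values and reproduces the 0.5-threshold accuracy used in the loop. It also shows nn.BCEWithLogitsLoss, which takes raw logits and applies the sigmoid internally in a numerically more stable way; it is an alternative, not what this post uses.

```python
import torch
import torch.nn as nn

# Hand-picked probabilities (after sigmoid) and labels, for illustration only
probs = torch.tensor([0.9, 0.2, 0.6, 0.4])
labels = torch.tensor([1.0, 0.0, 0.0, 1.0])

# Binary cross-entropy written out: mean of -[y*log(p) + (1-y)*log(1-p)]
manual = -(labels * torch.log(probs) + (1 - labels) * torch.log(1 - probs)).mean()
bce = nn.BCELoss()(probs, labels)
print(manual.item(), bce.item())

# BCEWithLogitsLoss takes raw scores (logits) and applies the sigmoid itself
logits = torch.log(probs / (1 - probs))  # inverse of the sigmoid
bce_logits = nn.BCEWithLogitsLoss()(logits, labels)
print(torch.allclose(bce, bce_logits, atol=1e-5))  # True

# Accuracy as in the training loop: threshold at 0.5, then compare with the labels
predicted = (probs > 0.5).float()
accuracy = (predicted == labels).float().mean().item()
print(accuracy)  # 0.5
```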

5. Model Evaluation

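After training, we evaluate the model's final performance on the test set. Note that the same test set also served as the validation set during training, so this number simply reproduces the final epoch's Val Acc (both are 0.7763 in the log in section 8) and is not an independent held-out estimate: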

# Evaluate the model
model.eval()
with torch.no_grad():
    outputs = model(X_test)
    predicted = (outputs > 0.5).float()
    accuracy = (predicted == y_test).float().mean()
print(f'Test Accuracy: {accuracy:.4f}')
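In eval mode (dropout off), the batched correct/total accuracy computed in the validation loop and the full-tensor accuracy computed here measure the same thing. The sketch below checks this with a small hypothetical stand-in classifier (an nn.Sequential with dropout, in place of the trained LSTMModel) on random data with an assumed 152 test samples.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical stand-in classifier; it contains dropout so that eval() matters
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.3), nn.Linear(16, 1), nn.Sigmoid())
X_test = torch.randn(152, 8)
y_test = torch.randint(0, 2, (152,)).float()
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=32)

model.eval()  # disables dropout, so predictions are deterministic
with torch.no_grad():
    # 1) Batched accuracy, as in the validation loop
    correct, total = 0, 0
    for inputs, labels in test_loader:
        predicted = (model(inputs).squeeze(-1) > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    batched_acc = correct / total

    # 2) Full-tensor accuracy, as in the final evaluation
    predicted = (model(X_test).squeeze(-1) > 0.5).float()
    full_acc = (predicted == y_test).float().mean().item()

print(batched_acc, full_acc)
```

The two values agree up to float32 rounding, which is why the final Test Accuracy line in the log matches the last epoch's Val Acc.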

6. Visualizing the Training Process

Finally, we plot the training and validation accuracy and loss curves:

# Plot the training curves
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_accs, label='Training Accuracy')
plt.plot(val_accs, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

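These plots help us:

  • see whether the model has converged

  • detect overfitting or underfitting

  • decide whether the training settings (such as the learning rate, number of epochs, or dropout rate) need adjusting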
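One common, optional way to act on these curves (not used in this post) is to track the epoch with the lowest validation loss and stop once it has not improved for a number of epochs. The sketch below is a pure function over a list like `val_losses`; the function name and its `patience` and `min_delta` parameters are hypothetical choices for illustration.

```python
def early_stop_epoch(val_losses, patience=20, min_delta=0.0):
    """Return (best_epoch, stop_epoch), both 1-based.

    best_epoch: epoch with the lowest validation loss seen before stopping.
    stop_epoch: first epoch at which `patience` consecutive epochs passed
    without an improvement larger than `min_delta`, or len(val_losses)
    if that never happens.
    """
    best_loss = float('inf')
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss - min_delta:
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses)


# Toy sequence: the loss improves for 5 epochs, then plateaus
toy_losses = [0.70, 0.60, 0.50, 0.45, 0.42, 0.43, 0.44, 0.42, 0.43]
print(early_stop_epoch(toy_losses, patience=3))  # (5, 8)
```

Inside the training loop, the same check would typically be paired with `torch.save(model.state_dict(), path)` whenever the validation loss improves. If the epoch is selected this way, the data used for selection should be separate from the data used to report the final accuracy.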

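7. Summary

In this post, we walked through building and training a two-layer LSTM model with PyTorch to predict diabetes. The key points were:

  1. Preparing and loading the data

  2. Designing the LSTM architecture

  3. Training and validation

  4. Evaluating and visualizing the model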

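Although LSTMs are usually applied to time-series data, here we applied one to non-sequential (tabular) data, which shows how flexible PyTorch is. Keep in mind that each sample enters the LSTM as a sequence of length 1, so the recurrent connections are never exercised and each LSTM layer effectively acts as a gated feed-forward layer. Adjusting the model architecture, the hyperparameters, and the data preprocessing can improve performance further.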
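In this model every sample is fed to the LSTM as a sequence of length 1 (the `unsqueeze(1)` in forward), and that has a concrete consequence. With seq_len = 1 and nn.LSTM's default zero initial states h_0 = c_0 = 0, the cell state is c_1 = f ⊙ c_0 + i ⊙ g = i ⊙ g and the output is h_1 = o ⊙ tanh(c_1), so neither the forget gate nor the recurrent weights W_hh affect the result. The sketch below recomputes one layer by hand from its parameters (`weight_ih_l0`, `bias_ih_l0`, `bias_hh_l0`, with the gates stacked in PyTorch's documented order i, f, g, o) and compares it with nn.LSTM.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=64, batch_first=True)
x = torch.randn(5, 8)  # 5 samples, 8 features, no time axis

# nn.LSTM on a length-1 sequence (zero initial h and c by default)
out, _ = lstm(x.unsqueeze(1))
h_lstm = out[:, 0, :]

# Same computation written out; the W_hh @ h0 term is zero and therefore omitted
pre = x @ lstm.weight_ih_l0.T + lstm.bias_ih_l0 + lstm.bias_hh_l0
i, _forget, g, o = pre.chunk(4, dim=1)  # gate order i, f, g, o; f is unused
c = torch.sigmoid(i) * torch.tanh(g)    # f * c0 vanishes because c0 = 0
h_manual = torch.sigmoid(o) * torch.tanh(c)

print(torch.allclose(h_lstm, h_manual, atol=1e-5))  # True
```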

I hope this post helps you understand how to implement an LSTM model in PyTorch!

8. Experiment Log and Download

The training log is as follows:

Epoch 1/300, Train Loss: 0.7193, Train Acc: 0.3558, Val Loss: 0.7248, Val Acc: 0.3092
Epoch 2/300, Train Loss: 0.7165, Train Acc: 0.3558, Val Loss: 0.7203, Val Acc: 0.3092
Epoch 3/300, Train Loss: 0.7121, Train Acc: 0.3558, Val Loss: 0.7158, Val Acc: 0.3092
Epoch 4/300, Train Loss: 0.7087, Train Acc: 0.3558, Val Loss: 0.7108, Val Acc: 0.3092
Epoch 5/300, Train Loss: 0.7042, Train Acc: 0.3558, Val Loss: 0.7053, Val Acc: 0.3092
Epoch 6/300, Train Loss: 0.7003, Train Acc: 0.3624, Val Loss: 0.6989, Val Acc: 0.3092
Epoch 7/300, Train Loss: 0.6951, Train Acc: 0.4498, Val Loss: 0.6920, Val Acc: 0.5066
Epoch 8/300, Train Loss: 0.6881, Train Acc: 0.6277, Val Loss: 0.6837, Val Acc: 0.7500
Epoch 9/300, Train Loss: 0.6823, Train Acc: 0.6590, Val Loss: 0.6734, Val Acc: 0.7105
Epoch 10/300, Train Loss: 0.6737, Train Acc: 0.6557, Val Loss: 0.6622, Val Acc: 0.6908
Epoch 11/300, Train Loss: 0.6653, Train Acc: 0.6491, Val Loss: 0.6496, Val Acc: 0.6974
Epoch 12/300, Train Loss: 0.6566, Train Acc: 0.6409, Val Loss: 0.6357, Val Acc: 0.6974
Epoch 13/300, Train Loss: 0.6457, Train Acc: 0.6458, Val Loss: 0.6215, Val Acc: 0.6908
Epoch 14/300, Train Loss: 0.6379, Train Acc: 0.6425, Val Loss: 0.6075, Val Acc: 0.6908
Epoch 15/300, Train Loss: 0.6306, Train Acc: 0.6425, Val Loss: 0.5973, Val Acc: 0.6908
Epoch 16/300, Train Loss: 0.6248, Train Acc: 0.6425, Val Loss: 0.5870, Val Acc: 0.6908
Epoch 17/300, Train Loss: 0.6203, Train Acc: 0.6442, Val Loss: 0.5778, Val Acc: 0.6908
Epoch 18/300, Train Loss: 0.6123, Train Acc: 0.6442, Val Loss: 0.5709, Val Acc: 0.6974
Epoch 19/300, Train Loss: 0.6142, Train Acc: 0.6425, Val Loss: 0.5648, Val Acc: 0.6974
Epoch 20/300, Train Loss: 0.6046, Train Acc: 0.6425, Val Loss: 0.5597, Val Acc: 0.6974
Epoch 21/300, Train Loss: 0.5988, Train Acc: 0.6425, Val Loss: 0.5547, Val Acc: 0.6974
Epoch 22/300, Train Loss: 0.5989, Train Acc: 0.6442, Val Loss: 0.5497, Val Acc: 0.6974
Epoch 23/300, Train Loss: 0.5993, Train Acc: 0.6392, Val Loss: 0.5454, Val Acc: 0.6974
Epoch 24/300, Train Loss: 0.5930, Train Acc: 0.6409, Val Loss: 0.5406, Val Acc: 0.7039
Epoch 25/300, Train Loss: 0.5872, Train Acc: 0.6392, Val Loss: 0.5362, Val Acc: 0.6974
Epoch 26/300, Train Loss: 0.5859, Train Acc: 0.6425, Val Loss: 0.5327, Val Acc: 0.6974
Epoch 27/300, Train Loss: 0.5859, Train Acc: 0.6442, Val Loss: 0.5285, Val Acc: 0.7039
Epoch 28/300, Train Loss: 0.5796, Train Acc: 0.6458, Val Loss: 0.5244, Val Acc: 0.7105
Epoch 29/300, Train Loss: 0.5778, Train Acc: 0.6524, Val Loss: 0.5212, Val Acc: 0.7171
Epoch 30/300, Train Loss: 0.5727, Train Acc: 0.6573, Val Loss: 0.5170, Val Acc: 0.7303
Epoch 31/300, Train Loss: 0.5682, Train Acc: 0.6623, Val Loss: 0.5122, Val Acc: 0.7434
Epoch 32/300, Train Loss: 0.5695, Train Acc: 0.6689, Val Loss: 0.5075, Val Acc: 0.7434
Epoch 33/300, Train Loss: 0.5667, Train Acc: 0.6771, Val Loss: 0.5044, Val Acc: 0.7566
Epoch 34/300, Train Loss: 0.5592, Train Acc: 0.6870, Val Loss: 0.4993, Val Acc: 0.7566
Epoch 35/300, Train Loss: 0.5555, Train Acc: 0.6903, Val Loss: 0.4958, Val Acc: 0.7632
Epoch 36/300, Train Loss: 0.5513, Train Acc: 0.7051, Val Loss: 0.4914, Val Acc: 0.7763
Epoch 37/300, Train Loss: 0.5483, Train Acc: 0.7035, Val Loss: 0.4870, Val Acc: 0.7829
Epoch 38/300, Train Loss: 0.5484, Train Acc: 0.7068, Val Loss: 0.4828, Val Acc: 0.7829
Epoch 39/300, Train Loss: 0.5436, Train Acc: 0.7216, Val Loss: 0.4794, Val Acc: 0.7961
Epoch 40/300, Train Loss: 0.5420, Train Acc: 0.7282, Val Loss: 0.4767, Val Acc: 0.8092
Epoch 41/300, Train Loss: 0.5353, Train Acc: 0.7216, Val Loss: 0.4727, Val Acc: 0.8289
Epoch 42/300, Train Loss: 0.5284, Train Acc: 0.7463, Val Loss: 0.4680, Val Acc: 0.8289
Epoch 43/300, Train Loss: 0.5287, Train Acc: 0.7463, Val Loss: 0.4651, Val Acc: 0.8158
Epoch 44/300, Train Loss: 0.5268, Train Acc: 0.7496, Val Loss: 0.4626, Val Acc: 0.8158
Epoch 45/300, Train Loss: 0.5204, Train Acc: 0.7529, Val Loss: 0.4592, Val Acc: 0.8158
Epoch 46/300, Train Loss: 0.5176, Train Acc: 0.7512, Val Loss: 0.4553, Val Acc: 0.8158
Epoch 47/300, Train Loss: 0.5191, Train Acc: 0.7562, Val Loss: 0.4510, Val Acc: 0.8158
Epoch 48/300, Train Loss: 0.5202, Train Acc: 0.7545, Val Loss: 0.4492, Val Acc: 0.8158
Epoch 49/300, Train Loss: 0.5073, Train Acc: 0.7611, Val Loss: 0.4473, Val Acc: 0.8158
Epoch 50/300, Train Loss: 0.5062, Train Acc: 0.7661, Val Loss: 0.4447, Val Acc: 0.8224
Epoch 51/300, Train Loss: 0.5083, Train Acc: 0.7661, Val Loss: 0.4426, Val Acc: 0.8289
Epoch 52/300, Train Loss: 0.5080, Train Acc: 0.7578, Val Loss: 0.4405, Val Acc: 0.8289
Epoch 53/300, Train Loss: 0.5068, Train Acc: 0.7595, Val Loss: 0.4389, Val Acc: 0.8092
Epoch 54/300, Train Loss: 0.4990, Train Acc: 0.7595, Val Loss: 0.4359, Val Acc: 0.8092
Epoch 55/300, Train Loss: 0.5007, Train Acc: 0.7578, Val Loss: 0.4346, Val Acc: 0.8092
Epoch 56/300, Train Loss: 0.5052, Train Acc: 0.7545, Val Loss: 0.4325, Val Acc: 0.8092
Epoch 57/300, Train Loss: 0.5023, Train Acc: 0.7562, Val Loss: 0.4327, Val Acc: 0.8026
Epoch 58/300, Train Loss: 0.4969, Train Acc: 0.7578, Val Loss: 0.4329, Val Acc: 0.7961
Epoch 59/300, Train Loss: 0.4955, Train Acc: 0.7562, Val Loss: 0.4284, Val Acc: 0.8026
Epoch 60/300, Train Loss: 0.4971, Train Acc: 0.7595, Val Loss: 0.4291, Val Acc: 0.7961
Epoch 61/300, Train Loss: 0.4928, Train Acc: 0.7545, Val Loss: 0.4271, Val Acc: 0.7961
Epoch 62/300, Train Loss: 0.4902, Train Acc: 0.7578, Val Loss: 0.4258, Val Acc: 0.7961
Epoch 63/300, Train Loss: 0.4909, Train Acc: 0.7463, Val Loss: 0.4241, Val Acc: 0.7961
Epoch 64/300, Train Loss: 0.4970, Train Acc: 0.7595, Val Loss: 0.4229, Val Acc: 0.7961
Epoch 65/300, Train Loss: 0.4892, Train Acc: 0.7595, Val Loss: 0.4234, Val Acc: 0.7961
Epoch 66/300, Train Loss: 0.4914, Train Acc: 0.7545, Val Loss: 0.4234, Val Acc: 0.7961
Epoch 67/300, Train Loss: 0.4937, Train Acc: 0.7628, Val Loss: 0.4232, Val Acc: 0.7961
Epoch 68/300, Train Loss: 0.4887, Train Acc: 0.7562, Val Loss: 0.4225, Val Acc: 0.7961
Epoch 69/300, Train Loss: 0.4890, Train Acc: 0.7562, Val Loss: 0.4214, Val Acc: 0.7961
Epoch 70/300, Train Loss: 0.4868, Train Acc: 0.7479, Val Loss: 0.4208, Val Acc: 0.7961
Epoch 71/300, Train Loss: 0.4883, Train Acc: 0.7529, Val Loss: 0.4197, Val Acc: 0.7961
Epoch 72/300, Train Loss: 0.4917, Train Acc: 0.7545, Val Loss: 0.4198, Val Acc: 0.7961
Epoch 73/300, Train Loss: 0.4849, Train Acc: 0.7628, Val Loss: 0.4182, Val Acc: 0.7961
Epoch 74/300, Train Loss: 0.4903, Train Acc: 0.7529, Val Loss: 0.4190, Val Acc: 0.7961
Epoch 75/300, Train Loss: 0.4965, Train Acc: 0.7562, Val Loss: 0.4196, Val Acc: 0.7961
Epoch 76/300, Train Loss: 0.4906, Train Acc: 0.7545, Val Loss: 0.4198, Val Acc: 0.7961
Epoch 77/300, Train Loss: 0.4893, Train Acc: 0.7529, Val Loss: 0.4189, Val Acc: 0.7961
Epoch 78/300, Train Loss: 0.4907, Train Acc: 0.7562, Val Loss: 0.4173, Val Acc: 0.7961
Epoch 79/300, Train Loss: 0.4828, Train Acc: 0.7496, Val Loss: 0.4168, Val Acc: 0.7961
Epoch 80/300, Train Loss: 0.4855, Train Acc: 0.7661, Val Loss: 0.4162, Val Acc: 0.8026
Epoch 81/300, Train Loss: 0.4880, Train Acc: 0.7578, Val Loss: 0.4169, Val Acc: 0.8026
Epoch 82/300, Train Loss: 0.4967, Train Acc: 0.7545, Val Loss: 0.4180, Val Acc: 0.7895
Epoch 83/300, Train Loss: 0.4864, Train Acc: 0.7578, Val Loss: 0.4187, Val Acc: 0.7829
Epoch 84/300, Train Loss: 0.4914, Train Acc: 0.7545, Val Loss: 0.4167, Val Acc: 0.7961
Epoch 85/300, Train Loss: 0.4818, Train Acc: 0.7595, Val Loss: 0.4154, Val Acc: 0.8026
Epoch 86/300, Train Loss: 0.4943, Train Acc: 0.7562, Val Loss: 0.4159, Val Acc: 0.8026
Epoch 87/300, Train Loss: 0.4830, Train Acc: 0.7595, Val Loss: 0.4165, Val Acc: 0.7961
Epoch 88/300, Train Loss: 0.4845, Train Acc: 0.7628, Val Loss: 0.4162, Val Acc: 0.7961
Epoch 89/300, Train Loss: 0.4790, Train Acc: 0.7611, Val Loss: 0.4163, Val Acc: 0.7961
Epoch 90/300, Train Loss: 0.4856, Train Acc: 0.7512, Val Loss: 0.4170, Val Acc: 0.7895
Epoch 91/300, Train Loss: 0.4853, Train Acc: 0.7562, Val Loss: 0.4151, Val Acc: 0.7961
Epoch 92/300, Train Loss: 0.4827, Train Acc: 0.7545, Val Loss: 0.4153, Val Acc: 0.7961
Epoch 93/300, Train Loss: 0.4887, Train Acc: 0.7661, Val Loss: 0.4175, Val Acc: 0.7895
Epoch 94/300, Train Loss: 0.4933, Train Acc: 0.7479, Val Loss: 0.4171, Val Acc: 0.7895
Epoch 95/300, Train Loss: 0.4836, Train Acc: 0.7545, Val Loss: 0.4171, Val Acc: 0.7895
Epoch 96/300, Train Loss: 0.4789, Train Acc: 0.7611, Val Loss: 0.4164, Val Acc: 0.7895
Epoch 97/300, Train Loss: 0.4831, Train Acc: 0.7529, Val Loss: 0.4159, Val Acc: 0.7895
Epoch 98/300, Train Loss: 0.4867, Train Acc: 0.7595, Val Loss: 0.4149, Val Acc: 0.7895
Epoch 99/300, Train Loss: 0.4818, Train Acc: 0.7595, Val Loss: 0.4154, Val Acc: 0.7895
Epoch 100/300, Train Loss: 0.4872, Train Acc: 0.7562, Val Loss: 0.4147, Val Acc: 0.7895
Epoch 101/300, Train Loss: 0.4828, Train Acc: 0.7529, Val Loss: 0.4158, Val Acc: 0.7895
Epoch 102/300, Train Loss: 0.4853, Train Acc: 0.7578, Val Loss: 0.4163, Val Acc: 0.7895
Epoch 103/300, Train Loss: 0.4844, Train Acc: 0.7628, Val Loss: 0.4170, Val Acc: 0.7829
Epoch 104/300, Train Loss: 0.4896, Train Acc: 0.7578, Val Loss: 0.4147, Val Acc: 0.7895
Epoch 105/300, Train Loss: 0.4853, Train Acc: 0.7562, Val Loss: 0.4162, Val Acc: 0.7895
Epoch 106/300, Train Loss: 0.4846, Train Acc: 0.7529, Val Loss: 0.4152, Val Acc: 0.7895
Epoch 107/300, Train Loss: 0.4832, Train Acc: 0.7562, Val Loss: 0.4159, Val Acc: 0.7829
Epoch 108/300, Train Loss: 0.4911, Train Acc: 0.7496, Val Loss: 0.4157, Val Acc: 0.7895
Epoch 109/300, Train Loss: 0.4808, Train Acc: 0.7496, Val Loss: 0.4163, Val Acc: 0.7829
Epoch 110/300, Train Loss: 0.4901, Train Acc: 0.7496, Val Loss: 0.4169, Val Acc: 0.7829
Epoch 111/300, Train Loss: 0.4832, Train Acc: 0.7529, Val Loss: 0.4154, Val Acc: 0.7829
Epoch 112/300, Train Loss: 0.4860, Train Acc: 0.7545, Val Loss: 0.4162, Val Acc: 0.7829
Epoch 113/300, Train Loss: 0.4828, Train Acc: 0.7611, Val Loss: 0.4156, Val Acc: 0.7829
Epoch 114/300, Train Loss: 0.4889, Train Acc: 0.7496, Val Loss: 0.4161, Val Acc: 0.7829
Epoch 115/300, Train Loss: 0.4863, Train Acc: 0.7496, Val Loss: 0.4150, Val Acc: 0.7829
Epoch 116/300, Train Loss: 0.4822, Train Acc: 0.7529, Val Loss: 0.4145, Val Acc: 0.7895
Epoch 117/300, Train Loss: 0.4790, Train Acc: 0.7562, Val Loss: 0.4148, Val Acc: 0.7829
Epoch 118/300, Train Loss: 0.4818, Train Acc: 0.7578, Val Loss: 0.4140, Val Acc: 0.7895
Epoch 119/300, Train Loss: 0.4840, Train Acc: 0.7529, Val Loss: 0.4152, Val Acc: 0.7829
Epoch 120/300, Train Loss: 0.4824, Train Acc: 0.7562, Val Loss: 0.4129, Val Acc: 0.7895
Epoch 121/300, Train Loss: 0.4890, Train Acc: 0.7512, Val Loss: 0.4136, Val Acc: 0.7895
Epoch 122/300, Train Loss: 0.4800, Train Acc: 0.7578, Val Loss: 0.4153, Val Acc: 0.7829
Epoch 123/300, Train Loss: 0.4896, Train Acc: 0.7562, Val Loss: 0.4158, Val Acc: 0.7829
Epoch 124/300, Train Loss: 0.4854, Train Acc: 0.7479, Val Loss: 0.4172, Val Acc: 0.7763
Epoch 125/300, Train Loss: 0.4822, Train Acc: 0.7578, Val Loss: 0.4158, Val Acc: 0.7829
Epoch 126/300, Train Loss: 0.4803, Train Acc: 0.7595, Val Loss: 0.4135, Val Acc: 0.7895
Epoch 127/300, Train Loss: 0.4859, Train Acc: 0.7496, Val Loss: 0.4142, Val Acc: 0.7829
Epoch 128/300, Train Loss: 0.4883, Train Acc: 0.7529, Val Loss: 0.4159, Val Acc: 0.7763
Epoch 129/300, Train Loss: 0.4854, Train Acc: 0.7545, Val Loss: 0.4165, Val Acc: 0.7763
Epoch 130/300, Train Loss: 0.4857, Train Acc: 0.7545, Val Loss: 0.4152, Val Acc: 0.7829
Epoch 131/300, Train Loss: 0.4758, Train Acc: 0.7562, Val Loss: 0.4143, Val Acc: 0.7829
Epoch 132/300, Train Loss: 0.4886, Train Acc: 0.7512, Val Loss: 0.4153, Val Acc: 0.7763
Epoch 133/300, Train Loss: 0.4854, Train Acc: 0.7463, Val Loss: 0.4144, Val Acc: 0.7829
Epoch 134/300, Train Loss: 0.4834, Train Acc: 0.7595, Val Loss: 0.4149, Val Acc: 0.7763
Epoch 135/300, Train Loss: 0.4779, Train Acc: 0.7545, Val Loss: 0.4147, Val Acc: 0.7763
Epoch 136/300, Train Loss: 0.4836, Train Acc: 0.7496, Val Loss: 0.4149, Val Acc: 0.7763
Epoch 137/300, Train Loss: 0.4798, Train Acc: 0.7562, Val Loss: 0.4140, Val Acc: 0.7829
Epoch 138/300, Train Loss: 0.4856, Train Acc: 0.7529, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 139/300, Train Loss: 0.4842, Train Acc: 0.7611, Val Loss: 0.4138, Val Acc: 0.7829
Epoch 140/300, Train Loss: 0.4772, Train Acc: 0.7578, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 141/300, Train Loss: 0.4861, Train Acc: 0.7496, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 142/300, Train Loss: 0.4779, Train Acc: 0.7578, Val Loss: 0.4154, Val Acc: 0.7763
Epoch 143/300, Train Loss: 0.4779, Train Acc: 0.7512, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 144/300, Train Loss: 0.4829, Train Acc: 0.7644, Val Loss: 0.4138, Val Acc: 0.7763
Epoch 145/300, Train Loss: 0.4801, Train Acc: 0.7628, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 146/300, Train Loss: 0.4842, Train Acc: 0.7496, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 147/300, Train Loss: 0.4845, Train Acc: 0.7529, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 148/300, Train Loss: 0.4775, Train Acc: 0.7595, Val Loss: 0.4144, Val Acc: 0.7763
Epoch 149/300, Train Loss: 0.4805, Train Acc: 0.7446, Val Loss: 0.4130, Val Acc: 0.7895
Epoch 150/300, Train Loss: 0.4838, Train Acc: 0.7562, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 151/300, Train Loss: 0.4900, Train Acc: 0.7562, Val Loss: 0.4151, Val Acc: 0.7763
Epoch 152/300, Train Loss: 0.4791, Train Acc: 0.7463, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 153/300, Train Loss: 0.4792, Train Acc: 0.7545, Val Loss: 0.4147, Val Acc: 0.7763
Epoch 154/300, Train Loss: 0.4814, Train Acc: 0.7512, Val Loss: 0.4152, Val Acc: 0.7763
Epoch 155/300, Train Loss: 0.4736, Train Acc: 0.7529, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 156/300, Train Loss: 0.4852, Train Acc: 0.7611, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 157/300, Train Loss: 0.4828, Train Acc: 0.7595, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 158/300, Train Loss: 0.4798, Train Acc: 0.7545, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 159/300, Train Loss: 0.4832, Train Acc: 0.7512, Val Loss: 0.4150, Val Acc: 0.7763
Epoch 160/300, Train Loss: 0.4789, Train Acc: 0.7512, Val Loss: 0.4150, Val Acc: 0.7763
Epoch 161/300, Train Loss: 0.4806, Train Acc: 0.7479, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 162/300, Train Loss: 0.4835, Train Acc: 0.7595, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 163/300, Train Loss: 0.4796, Train Acc: 0.7479, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 164/300, Train Loss: 0.4821, Train Acc: 0.7529, Val Loss: 0.4158, Val Acc: 0.7697
Epoch 165/300, Train Loss: 0.4828, Train Acc: 0.7545, Val Loss: 0.4133, Val Acc: 0.7829
Epoch 166/300, Train Loss: 0.4878, Train Acc: 0.7512, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 167/300, Train Loss: 0.4854, Train Acc: 0.7463, Val Loss: 0.4167, Val Acc: 0.7697
Epoch 168/300, Train Loss: 0.4875, Train Acc: 0.7479, Val Loss: 0.4152, Val Acc: 0.7763
Epoch 169/300, Train Loss: 0.4864, Train Acc: 0.7479, Val Loss: 0.4150, Val Acc: 0.7763
Epoch 170/300, Train Loss: 0.4763, Train Acc: 0.7529, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 171/300, Train Loss: 0.4843, Train Acc: 0.7446, Val Loss: 0.4154, Val Acc: 0.7763
Epoch 172/300, Train Loss: 0.4769, Train Acc: 0.7545, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 173/300, Train Loss: 0.4846, Train Acc: 0.7595, Val Loss: 0.4155, Val Acc: 0.7697
Epoch 174/300, Train Loss: 0.4831, Train Acc: 0.7512, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 175/300, Train Loss: 0.4922, Train Acc: 0.7496, Val Loss: 0.4144, Val Acc: 0.7763
Epoch 176/300, Train Loss: 0.4826, Train Acc: 0.7479, Val Loss: 0.4161, Val Acc: 0.7697
Epoch 177/300, Train Loss: 0.4793, Train Acc: 0.7611, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 178/300, Train Loss: 0.4768, Train Acc: 0.7644, Val Loss: 0.4134, Val Acc: 0.7829
Epoch 179/300, Train Loss: 0.4837, Train Acc: 0.7562, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 180/300, Train Loss: 0.4831, Train Acc: 0.7496, Val Loss: 0.4136, Val Acc: 0.7829
Epoch 181/300, Train Loss: 0.4824, Train Acc: 0.7562, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 182/300, Train Loss: 0.4786, Train Acc: 0.7562, Val Loss: 0.4133, Val Acc: 0.7829
Epoch 183/300, Train Loss: 0.4826, Train Acc: 0.7628, Val Loss: 0.4139, Val Acc: 0.7829
Epoch 184/300, Train Loss: 0.4858, Train Acc: 0.7545, Val Loss: 0.4160, Val Acc: 0.7697
Epoch 185/300, Train Loss: 0.4847, Train Acc: 0.7529, Val Loss: 0.4142, Val Acc: 0.7829
Epoch 186/300, Train Loss: 0.4776, Train Acc: 0.7496, Val Loss: 0.4148, Val Acc: 0.7763
Epoch 187/300, Train Loss: 0.4846, Train Acc: 0.7562, Val Loss: 0.4147, Val Acc: 0.7763
Epoch 188/300, Train Loss: 0.4744, Train Acc: 0.7529, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 189/300, Train Loss: 0.4845, Train Acc: 0.7545, Val Loss: 0.4147, Val Acc: 0.7763
Epoch 190/300, Train Loss: 0.4802, Train Acc: 0.7512, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 191/300, Train Loss: 0.4831, Train Acc: 0.7496, Val Loss: 0.4149, Val Acc: 0.7697
Epoch 192/300, Train Loss: 0.4767, Train Acc: 0.7496, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 193/300, Train Loss: 0.4786, Train Acc: 0.7512, Val Loss: 0.4139, Val Acc: 0.7829
Epoch 194/300, Train Loss: 0.4810, Train Acc: 0.7545, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 195/300, Train Loss: 0.4748, Train Acc: 0.7512, Val Loss: 0.4131, Val Acc: 0.7829
Epoch 196/300, Train Loss: 0.4750, Train Acc: 0.7479, Val Loss: 0.4131, Val Acc: 0.7829
Epoch 197/300, Train Loss: 0.4801, Train Acc: 0.7512, Val Loss: 0.4131, Val Acc: 0.7829
Epoch 198/300, Train Loss: 0.4819, Train Acc: 0.7545, Val Loss: 0.4129, Val Acc: 0.7829
Epoch 199/300, Train Loss: 0.4840, Train Acc: 0.7496, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 200/300, Train Loss: 0.4786, Train Acc: 0.7611, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 201/300, Train Loss: 0.4813, Train Acc: 0.7446, Val Loss: 0.4157, Val Acc: 0.7697
Epoch 202/300, Train Loss: 0.4858, Train Acc: 0.7545, Val Loss: 0.4145, Val Acc: 0.7697
Epoch 203/300, Train Loss: 0.4814, Train Acc: 0.7562, Val Loss: 0.4150, Val Acc: 0.7697
Epoch 204/300, Train Loss: 0.4797, Train Acc: 0.7611, Val Loss: 0.4134, Val Acc: 0.7829
Epoch 205/300, Train Loss: 0.4863, Train Acc: 0.7628, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 206/300, Train Loss: 0.4813, Train Acc: 0.7545, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 207/300, Train Loss: 0.4817, Train Acc: 0.7545, Val Loss: 0.4138, Val Acc: 0.7829
Epoch 208/300, Train Loss: 0.4877, Train Acc: 0.7661, Val Loss: 0.4140, Val Acc: 0.7829
Epoch 209/300, Train Loss: 0.4787, Train Acc: 0.7578, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 210/300, Train Loss: 0.4836, Train Acc: 0.7430, Val Loss: 0.4145, Val Acc: 0.7697
Epoch 211/300, Train Loss: 0.4743, Train Acc: 0.7578, Val Loss: 0.4139, Val Acc: 0.7829
Epoch 212/300, Train Loss: 0.4795, Train Acc: 0.7529, Val Loss: 0.4141, Val Acc: 0.7829
Epoch 213/300, Train Loss: 0.4821, Train Acc: 0.7512, Val Loss: 0.4139, Val Acc: 0.7829
Epoch 214/300, Train Loss: 0.4805, Train Acc: 0.7545, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 215/300, Train Loss: 0.4807, Train Acc: 0.7529, Val Loss: 0.4150, Val Acc: 0.7697
Epoch 216/300, Train Loss: 0.4793, Train Acc: 0.7578, Val Loss: 0.4134, Val Acc: 0.7829
Epoch 217/300, Train Loss: 0.4816, Train Acc: 0.7479, Val Loss: 0.4148, Val Acc: 0.7697
Epoch 218/300, Train Loss: 0.4831, Train Acc: 0.7479, Val Loss: 0.4124, Val Acc: 0.7829
Epoch 219/300, Train Loss: 0.4714, Train Acc: 0.7529, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 220/300, Train Loss: 0.4795, Train Acc: 0.7479, Val Loss: 0.4130, Val Acc: 0.7829
Epoch 221/300, Train Loss: 0.4822, Train Acc: 0.7611, Val Loss: 0.4124, Val Acc: 0.7829
Epoch 222/300, Train Loss: 0.4892, Train Acc: 0.7529, Val Loss: 0.4135, Val Acc: 0.7829
Epoch 223/300, Train Loss: 0.4810, Train Acc: 0.7595, Val Loss: 0.4133, Val Acc: 0.7829
Epoch 224/300, Train Loss: 0.4809, Train Acc: 0.7529, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 225/300, Train Loss: 0.4789, Train Acc: 0.7512, Val Loss: 0.4135, Val Acc: 0.7763
Epoch 226/300, Train Loss: 0.4805, Train Acc: 0.7545, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 227/300, Train Loss: 0.4748, Train Acc: 0.7512, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 228/300, Train Loss: 0.4811, Train Acc: 0.7529, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 229/300, Train Loss: 0.4780, Train Acc: 0.7562, Val Loss: 0.4131, Val Acc: 0.7829
Epoch 230/300, Train Loss: 0.4851, Train Acc: 0.7595, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 231/300, Train Loss: 0.4823, Train Acc: 0.7479, Val Loss: 0.4130, Val Acc: 0.7829
Epoch 232/300, Train Loss: 0.4782, Train Acc: 0.7512, Val Loss: 0.4135, Val Acc: 0.7829
Epoch 233/300, Train Loss: 0.4785, Train Acc: 0.7512, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 234/300, Train Loss: 0.4799, Train Acc: 0.7578, Val Loss: 0.4150, Val Acc: 0.7697
Epoch 235/300, Train Loss: 0.4798, Train Acc: 0.7545, Val Loss: 0.4138, Val Acc: 0.7763
Epoch 236/300, Train Loss: 0.4818, Train Acc: 0.7529, Val Loss: 0.4151, Val Acc: 0.7697
Epoch 237/300, Train Loss: 0.4784, Train Acc: 0.7562, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 238/300, Train Loss: 0.4760, Train Acc: 0.7529, Val Loss: 0.4119, Val Acc: 0.7829
Epoch 239/300, Train Loss: 0.4781, Train Acc: 0.7529, Val Loss: 0.4118, Val Acc: 0.7829
Epoch 240/300, Train Loss: 0.4797, Train Acc: 0.7545, Val Loss: 0.4120, Val Acc: 0.7829
Epoch 241/300, Train Loss: 0.4793, Train Acc: 0.7578, Val Loss: 0.4133, Val Acc: 0.7829
Epoch 242/300, Train Loss: 0.4825, Train Acc: 0.7545, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 243/300, Train Loss: 0.4781, Train Acc: 0.7479, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 244/300, Train Loss: 0.4802, Train Acc: 0.7512, Val Loss: 0.4133, Val Acc: 0.7763
Epoch 245/300, Train Loss: 0.4830, Train Acc: 0.7479, Val Loss: 0.4124, Val Acc: 0.7829
Epoch 246/300, Train Loss: 0.4844, Train Acc: 0.7578, Val Loss: 0.4135, Val Acc: 0.7763
Epoch 247/300, Train Loss: 0.4757, Train Acc: 0.7496, Val Loss: 0.4128, Val Acc: 0.7829
Epoch 248/300, Train Loss: 0.4774, Train Acc: 0.7611, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 249/300, Train Loss: 0.4850, Train Acc: 0.7479, Val Loss: 0.4132, Val Acc: 0.7829
Epoch 250/300, Train Loss: 0.4811, Train Acc: 0.7479, Val Loss: 0.4131, Val Acc: 0.7829
Epoch 251/300, Train Loss: 0.4812, Train Acc: 0.7545, Val Loss: 0.4137, Val Acc: 0.7763
Epoch 252/300, Train Loss: 0.4827, Train Acc: 0.7512, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 253/300, Train Loss: 0.4768, Train Acc: 0.7578, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 254/300, Train Loss: 0.4792, Train Acc: 0.7644, Val Loss: 0.4130, Val Acc: 0.7829
Epoch 255/300, Train Loss: 0.4812, Train Acc: 0.7545, Val Loss: 0.4134, Val Acc: 0.7763
Epoch 256/300, Train Loss: 0.4768, Train Acc: 0.7529, Val Loss: 0.4131, Val Acc: 0.7763
Epoch 257/300, Train Loss: 0.4785, Train Acc: 0.7595, Val Loss: 0.4128, Val Acc: 0.7829
Epoch 258/300, Train Loss: 0.4817, Train Acc: 0.7578, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 259/300, Train Loss: 0.4809, Train Acc: 0.7512, Val Loss: 0.4139, Val Acc: 0.7763
Epoch 260/300, Train Loss: 0.4777, Train Acc: 0.7529, Val Loss: 0.4144, Val Acc: 0.7763
Epoch 261/300, Train Loss: 0.4823, Train Acc: 0.7479, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 262/300, Train Loss: 0.4783, Train Acc: 0.7578, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 263/300, Train Loss: 0.4813, Train Acc: 0.7512, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 264/300, Train Loss: 0.4797, Train Acc: 0.7611, Val Loss: 0.4144, Val Acc: 0.7763
Epoch 265/300, Train Loss: 0.4751, Train Acc: 0.7562, Val Loss: 0.4142, Val Acc: 0.7763
Epoch 266/300, Train Loss: 0.4771, Train Acc: 0.7545, Val Loss: 0.4139, Val Acc: 0.7763
Epoch 267/300, Train Loss: 0.4809, Train Acc: 0.7512, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 268/300, Train Loss: 0.4726, Train Acc: 0.7562, Val Loss: 0.4129, Val Acc: 0.7829
Epoch 269/300, Train Loss: 0.4746, Train Acc: 0.7529, Val Loss: 0.4137, Val Acc: 0.7763
Epoch 270/300, Train Loss: 0.4800, Train Acc: 0.7463, Val Loss: 0.4129, Val Acc: 0.7763
Epoch 271/300, Train Loss: 0.4810, Train Acc: 0.7562, Val Loss: 0.4129, Val Acc: 0.7763
Epoch 272/300, Train Loss: 0.4780, Train Acc: 0.7479, Val Loss: 0.4130, Val Acc: 0.7829
Epoch 273/300, Train Loss: 0.4790, Train Acc: 0.7562, Val Loss: 0.4129, Val Acc: 0.7763
Epoch 274/300, Train Loss: 0.4825, Train Acc: 0.7545, Val Loss: 0.4129, Val Acc: 0.7763
Epoch 275/300, Train Loss: 0.4742, Train Acc: 0.7512, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 276/300, Train Loss: 0.4882, Train Acc: 0.7381, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 277/300, Train Loss: 0.4838, Train Acc: 0.7562, Val Loss: 0.4151, Val Acc: 0.7697
Epoch 278/300, Train Loss: 0.4779, Train Acc: 0.7529, Val Loss: 0.4145, Val Acc: 0.7763
Epoch 279/300, Train Loss: 0.4826, Train Acc: 0.7529, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 280/300, Train Loss: 0.4847, Train Acc: 0.7529, Val Loss: 0.4153, Val Acc: 0.7697
Epoch 281/300, Train Loss: 0.4811, Train Acc: 0.7545, Val Loss: 0.4163, Val Acc: 0.7697
Epoch 282/300, Train Loss: 0.4767, Train Acc: 0.7545, Val Loss: 0.4149, Val Acc: 0.7697
Epoch 283/300, Train Loss: 0.4808, Train Acc: 0.7512, Val Loss: 0.4146, Val Acc: 0.7763
Epoch 284/300, Train Loss: 0.4775, Train Acc: 0.7578, Val Loss: 0.4131, Val Acc: 0.7763
Epoch 285/300, Train Loss: 0.4809, Train Acc: 0.7545, Val Loss: 0.4139, Val Acc: 0.7763
Epoch 286/300, Train Loss: 0.4757, Train Acc: 0.7529, Val Loss: 0.4149, Val Acc: 0.7763
Epoch 287/300, Train Loss: 0.4794, Train Acc: 0.7545, Val Loss: 0.4127, Val Acc: 0.7763
Epoch 288/300, Train Loss: 0.4823, Train Acc: 0.7479, Val Loss: 0.4144, Val Acc: 0.7763
Epoch 289/300, Train Loss: 0.4789, Train Acc: 0.7578, Val Loss: 0.4140, Val Acc: 0.7763
Epoch 290/300, Train Loss: 0.4755, Train Acc: 0.7430, Val Loss: 0.4138, Val Acc: 0.7763
Epoch 291/300, Train Loss: 0.4809, Train Acc: 0.7463, Val Loss: 0.4131, Val Acc: 0.7763
Epoch 292/300, Train Loss: 0.4834, Train Acc: 0.7595, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 293/300, Train Loss: 0.4812, Train Acc: 0.7479, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 294/300, Train Loss: 0.4816, Train Acc: 0.7529, Val Loss: 0.4136, Val Acc: 0.7763
Epoch 295/300, Train Loss: 0.4773, Train Acc: 0.7545, Val Loss: 0.4130, Val Acc: 0.7763
Epoch 296/300, Train Loss: 0.4759, Train Acc: 0.7430, Val Loss: 0.4139, Val Acc: 0.7763
Epoch 297/300, Train Loss: 0.4806, Train Acc: 0.7545, Val Loss: 0.4131, Val Acc: 0.7763
Epoch 298/300, Train Loss: 0.4826, Train Acc: 0.7578, Val Loss: 0.4141, Val Acc: 0.7763
Epoch 299/300, Train Loss: 0.4713, Train Acc: 0.7595, Val Loss: 0.4143, Val Acc: 0.7763
Epoch 300/300, Train Loss: 0.4777, Train Acc: 0.7545, Val Loss: 0.4131, Val Acc: 0.7763
Test Accuracy: 0.7763

[Figure: training and validation accuracy and loss curves]

Full code:

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

# Load the data
data = pd.read_csv('diabetes.csv', header=None)
X = data.iloc[:, :-1].values
y = data.iloc[:, -1].values

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert to PyTorch tensors. No unsqueeze(2) is needed here: the model adds the sequence dimension itself
X_train = torch.FloatTensor(X_train)  # shape (num_samples, 8)
X_test = torch.FloatTensor(X_test)    # shape (num_samples, 8)
y_train = torch.FloatTensor(y_train)
y_test = torch.FloatTensor(y_test)

# Create the DataLoaders
train_data = TensorDataset(X_train, y_train)
test_data = TensorDataset(X_test, y_test)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32)


# Define the LSTM model
class LSTMModel(nn.Module):
    def __init__(self, input_size=8, hidden_size1=64, hidden_size2=32):
        super(LSTMModel, self).__init__()
        self.lstm1 = nn.LSTM(input_size, hidden_size1, batch_first=True)
        self.dropout1 = nn.Dropout(0.3)
        self.lstm2 = nn.LSTM(hidden_size1, hidden_size2, batch_first=True)
        self.dropout2 = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_size2, 1)

    def forward(self, x):
        # Add a sequence-length dimension: (batch_size, 8) -> (batch_size, 1, 8)
        x = x.unsqueeze(1)
        # First LSTM layer
        x, _ = self.lstm1(x)
        x = self.dropout1(x)
        # Second LSTM layer
        x, (hn, cn) = self.lstm2(x)
        x = self.dropout2(hn[-1])  # hidden state at the last time step
        x = self.fc(x)
        # squeeze(-1) keeps the batch dimension even if a batch contains a single sample
        return torch.sigmoid(x.squeeze(-1))


# Initialize the model
model = LSTMModel(input_size=8)  # 8 features
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0002)

# Records for training and validation
train_losses = []
train_accs = []
val_losses = []
val_accs = []

# Train the model
epochs = 300
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        predicted = (outputs > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    train_loss = running_loss / len(train_loader)
    train_acc = correct / total
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_loss = val_loss / len(test_loader)
    val_acc = correct / total
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    print(f'Epoch {epoch + 1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

# Evaluate the model
model.eval()
with torch.no_grad():
    outputs = model(X_test)
    predicted = (outputs > 0.5).float()
    accuracy = (predicted == y_test).float().mean()
print(f'Test Accuracy: {accuracy:.4f}')

# Plot the training curves
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_accs, label='Training Accuracy')
plt.plot(val_accs, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

Download: LSTM-based diabetes classification project resources (CSDN Library)
