Pytorch|RNN-心脏病预测

发布于:2025-04-03 ⋅ 阅读:(15) ⋅ 点赞:(0)

- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/rnFa-IeY93EpjVu0yzzjkw) 中的学习记录博客**
- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

一:前期准备工作

1.设置硬件设备

import numpy as np
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
import seaborn as sns

# Select the compute device: prefer a CUDA GPU when one is available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

2.导入数据

# Load the heart-disease dataset from a Colab-mounted Google Drive path.
# NOTE(review): downstream code treats the last column as the binary label and
# the remaining columns as 13 numeric features (the RNN uses input_size=13) —
# confirm the CSV actually has that layout.
df = pd.read_csv("/content/drive/MyDrive/heart.csv")
df

二:构建数据集

1.标准化

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Split the dataframe into features (every column but the last) and the label
# (last column), then z-score each feature column to zero mean / unit variance.
X, y = df.iloc[:, :-1], df.iloc[:, -1]

sc = StandardScaler()
X = sc.fit_transform(X)

2.划分数据集

# Convert features/labels to tensors and hold out 10% of the rows for testing.
# random_state=1 keeps the split reproducible across runs.
X = torch.tensor(np.asarray(X), dtype=torch.float32)
y = torch.tensor(np.asarray(y), dtype=torch.int64)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.1, random_state=1
)
X_train.shape, y_train.shape

3.构建数据加载器

from torch.utils.data import TensorDataset,DataLoader

# Wrap the tensors in DataLoaders (batch size 64, original row order kept —
# shuffle=False on the training loader too).
train_ds = TensorDataset(X_train, y_train)
test_ds = TensorDataset(X_test, y_test)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=64, shuffle=False)

三:模型训练

1.构建模型

class model_rnn(nn.Module):
    """RNN classifier: 13 input features -> 200 RNN hidden units -> 50 -> 2 logits."""

    def __init__(self):
        super().__init__()
        # Layer creation order is kept stable so seeded weight init is reproducible.
        self.rnn0 = nn.RNN(input_size=13, hidden_size=200,
                           num_layers=1, batch_first=True)
        self.fc0 = nn.Linear(200, 50)
        self.fc1 = nn.Linear(50, 2)

    def forward(self, x):
        # The RNN's final hidden state is discarded; the classifier head is
        # applied to the full output sequence.
        features, _ = self.rnn0(x)
        return self.fc1(self.fc0(features))

# Instantiate the model and move its parameters to the selected device.
model = model_rnn().to(device)
model

# Sanity check of output shape. NOTE(review): a 2-D (30, 13) input is
# interpreted by nn.RNN as one unbatched sequence of length 30 with 13
# features per step, so the expected result here is (30, 2).
model(torch.rand(30,13).to(device)).shape

2.定义训练函数

def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch over `dataloader`.

    Returns a tuple (accuracy over all samples, mean loss per batch).
    Caller is responsible for putting the model in train mode first.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)

    running_loss = 0.0
    correct = 0.0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Forward pass and loss.
        pred = model(X)
        loss = loss_fn(pred, y)

        # Standard backprop step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        running_loss += loss.item()

    return correct / n_samples, running_loss / n_batches

3.定义测试函数

def test(dataloader, model, loss_fn):
    """Evaluate the model on `dataloader` without gradient tracking.

    Returns a tuple (accuracy over all samples, mean loss per batch).
    Caller is responsible for putting the model in eval mode first.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)

    total_loss = 0.0
    correct = 0.0
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            target_pred = model(imgs)
            total_loss += loss_fn(target_pred, target).item()
            correct += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    return correct / n_samples, total_loss / n_batches

4.正式训练模型

# Training setup: cross-entropy loss, Adam at lr=1e-4, 50 epochs.
loss_fn = nn.CrossEntropyLoss()
learn_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
epochs = 50

# Per-epoch history, consumed later by the accuracy/loss plots.
train_loss, train_acc = [], []
test_loss, test_acc = [], []

# Hoisted out of the loop: the format string never changes.
template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},lr:{:.2E}')

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # Same value as optimizer.state_dict()['param_groups'][0]['lr'].
    lr = optimizer.param_groups[0]['lr']
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))
print("=" * 20, 'Done', "=" * 20)
Epoch: 1,Train_acc:47.4%,Train_loss:0.698,Test_acc:54.8%,Test_loss:0.680,lr:1.00E-04
Epoch: 2,Train_acc:57.7%,Train_loss:0.682,Test_acc:71.0%,Test_loss:0.663,lr:1.00E-04
Epoch: 3,Train_acc:66.2%,Train_loss:0.668,Test_acc:77.4%,Test_loss:0.646,lr:1.00E-04
Epoch: 4,Train_acc:71.3%,Train_loss:0.654,Test_acc:83.9%,Test_loss:0.629,lr:1.00E-04
Epoch: 5,Train_acc:74.3%,Train_loss:0.640,Test_acc:83.9%,Test_loss:0.614,lr:1.00E-04
Epoch: 6,Train_acc:77.2%,Train_loss:0.626,Test_acc:83.9%,Test_loss:0.598,lr:1.00E-04
Epoch: 7,Train_acc:77.9%,Train_loss:0.613,Test_acc:83.9%,Test_loss:0.583,lr:1.00E-04
Epoch: 8,Train_acc:79.4%,Train_loss:0.599,Test_acc:83.9%,Test_loss:0.568,lr:1.00E-04
Epoch: 9,Train_acc:79.4%,Train_loss:0.586,Test_acc:83.9%,Test_loss:0.554,lr:1.00E-04
Epoch:10,Train_acc:79.8%,Train_loss:0.572,Test_acc:83.9%,Test_loss:0.540,lr:1.00E-04
Epoch:11,Train_acc:80.5%,Train_loss:0.559,Test_acc:83.9%,Test_loss:0.527,lr:1.00E-04
Epoch:12,Train_acc:81.2%,Train_loss:0.545,Test_acc:83.9%,Test_loss:0.515,lr:1.00E-04
Epoch:13,Train_acc:81.6%,Train_loss:0.531,Test_acc:83.9%,Test_loss:0.503,lr:1.00E-04
Epoch:14,Train_acc:80.9%,Train_loss:0.517,Test_acc:83.9%,Test_loss:0.492,lr:1.00E-04
Epoch:15,Train_acc:80.9%,Train_loss:0.503,Test_acc:83.9%,Test_loss:0.483,lr:1.00E-04
Epoch:16,Train_acc:81.6%,Train_loss:0.489,Test_acc:83.9%,Test_loss:0.475,lr:1.00E-04
Epoch:17,Train_acc:81.6%,Train_loss:0.476,Test_acc:83.9%,Test_loss:0.468,lr:1.00E-04
Epoch:18,Train_acc:82.0%,Train_loss:0.463,Test_acc:83.9%,Test_loss:0.462,lr:1.00E-04
Epoch:19,Train_acc:81.6%,Train_loss:0.450,Test_acc:83.9%,Test_loss:0.458,lr:1.00E-04
Epoch:20,Train_acc:81.6%,Train_loss:0.438,Test_acc:83.9%,Test_loss:0.454,lr:1.00E-04
Epoch:21,Train_acc:82.4%,Train_loss:0.426,Test_acc:83.9%,Test_loss:0.450,lr:1.00E-04
Epoch:22,Train_acc:83.1%,Train_loss:0.415,Test_acc:83.9%,Test_loss:0.445,lr:1.00E-04
Epoch:23,Train_acc:83.1%,Train_loss:0.404,Test_acc:83.9%,Test_loss:0.440,lr:1.00E-04
Epoch:24,Train_acc:83.5%,Train_loss:0.394,Test_acc:83.9%,Test_loss:0.435,lr:1.00E-04
Epoch:25,Train_acc:84.2%,Train_loss:0.384,Test_acc:83.9%,Test_loss:0.429,lr:1.00E-04
Epoch:26,Train_acc:84.9%,Train_loss:0.374,Test_acc:83.9%,Test_loss:0.423,lr:1.00E-04
Epoch:27,Train_acc:86.0%,Train_loss:0.364,Test_acc:83.9%,Test_loss:0.418,lr:1.00E-04
Epoch:28,Train_acc:86.4%,Train_loss:0.355,Test_acc:83.9%,Test_loss:0.414,lr:1.00E-04
Epoch:29,Train_acc:86.8%,Train_loss:0.347,Test_acc:80.6%,Test_loss:0.410,lr:1.00E-04
Epoch:30,Train_acc:86.8%,Train_loss:0.338,Test_acc:83.9%,Test_loss:0.408,lr:1.00E-04
Epoch:31,Train_acc:87.1%,Train_loss:0.331,Test_acc:83.9%,Test_loss:0.406,lr:1.00E-04
Epoch:32,Train_acc:87.9%,Train_loss:0.323,Test_acc:83.9%,Test_loss:0.405,lr:1.00E-04
Epoch:33,Train_acc:89.3%,Train_loss:0.315,Test_acc:83.9%,Test_loss:0.404,lr:1.00E-04
Epoch:34,Train_acc:89.7%,Train_loss:0.308,Test_acc:83.9%,Test_loss:0.403,lr:1.00E-04
Epoch:35,Train_acc:89.7%,Train_loss:0.301,Test_acc:83.9%,Test_loss:0.402,lr:1.00E-04
Epoch:36,Train_acc:89.0%,Train_loss:0.294,Test_acc:83.9%,Test_loss:0.402,lr:1.00E-04
Epoch:37,Train_acc:89.3%,Train_loss:0.287,Test_acc:87.1%,Test_loss:0.401,lr:1.00E-04
Epoch:38,Train_acc:89.7%,Train_loss:0.280,Test_acc:87.1%,Test_loss:0.401,lr:1.00E-04
Epoch:39,Train_acc:89.0%,Train_loss:0.274,Test_acc:87.1%,Test_loss:0.402,lr:1.00E-04
Epoch:40,Train_acc:89.3%,Train_loss:0.267,Test_acc:87.1%,Test_loss:0.403,lr:1.00E-04
Epoch:41,Train_acc:89.7%,Train_loss:0.261,Test_acc:87.1%,Test_loss:0.404,lr:1.00E-04
Epoch:42,Train_acc:89.3%,Train_loss:0.255,Test_acc:87.1%,Test_loss:0.405,lr:1.00E-04
Epoch:43,Train_acc:89.3%,Train_loss:0.249,Test_acc:87.1%,Test_loss:0.406,lr:1.00E-04
Epoch:44,Train_acc:90.1%,Train_loss:0.243,Test_acc:87.1%,Test_loss:0.408,lr:1.00E-04
Epoch:45,Train_acc:90.4%,Train_loss:0.237,Test_acc:87.1%,Test_loss:0.410,lr:1.00E-04
Epoch:46,Train_acc:90.4%,Train_loss:0.231,Test_acc:87.1%,Test_loss:0.412,lr:1.00E-04
Epoch:47,Train_acc:90.4%,Train_loss:0.225,Test_acc:87.1%,Test_loss:0.414,lr:1.00E-04
Epoch:48,Train_acc:90.8%,Train_loss:0.220,Test_acc:87.1%,Test_loss:0.416,lr:1.00E-04
Epoch:49,Train_acc:91.2%,Train_loss:0.214,Test_acc:87.1%,Test_loss:0.418,lr:1.00E-04
Epoch:50,Train_acc:92.3%,Train_loss:0.208,Test_acc:83.9%,Test_loss:0.421,lr:1.00E-04
==================== Done ====================

四:模型评估

1.Loss 与 Accuracy 图

# Fix: dropped the stray `from logging import currentframe` — an accidental
# autocomplete import that is used nowhere in the file (and not a public
# logging API in any case).
import matplotlib.pyplot as plt
from datetime import datetime
import warnings

# Suppress warnings (e.g. missing-glyph font warnings from matplotlib).
warnings.filterwarnings('ignore')

current_time = datetime.now()
plt.rcParams['font.sans-serif'] = ['SimHei']  # Chinese-capable font
plt.rcParams['axes.unicode_minus'] = False    # render minus signs with SimHei
plt.rcParams['figure.dpi'] = 200

# Plot the recorded per-epoch accuracy and loss curves side by side.
epochs_range = range(1, epochs + 1)
plt.figure(figsize=(12, 3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Train Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Testing Accuracy')
plt.xlabel(current_time)  # timestamp the figure for record-keeping

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Testing Loss')
plt.show()

2.混淆矩阵

print("===========输入数据Shape为===========")
print("X_test.shape:", X_test.shape)
print("y_test.shape:", y_test.shape)

# Fix: run inference under torch.no_grad() so no autograd graph is built for
# the whole test set; also set eval mode explicitly rather than relying on
# the last training-loop iteration having left the model in eval mode.
model.eval()
with torch.no_grad():
    pred = model(X_test.to(device))
pred_labels = pred.argmax(dim=1).cpu().numpy()
print("\n===========输出数据Shape为===========")
print("pred.shape:", pred_labels.shape)

from sklearn.metrics import confusion_matrix, classification_report

# Rows are true labels, columns are predicted labels.
cm = confusion_matrix(y_test, pred_labels)

plt.figure(figsize=(6, 5))
plt.suptitle('')
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")

plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title('Confusion Matrix', fontsize=12)
plt.xlabel('Predicted labels', fontsize=10)
plt.ylabel('True labels', fontsize=10)
plt.tight_layout()
plt.show()

3.调用模型进行预测

# Predict the class of a single held-out sample. reshape(1, -1) keeps the
# input 2-D, matching the shape the model was probed with earlier.
test_X = X_test[0].reshape(1, -1)

# Fix: inference wrapped in torch.no_grad() to avoid building an autograd
# graph for a forward-only call.
model.eval()
with torch.no_grad():
    pred = model(test_X.to(device)).argmax(1).item()

print("预测结果为:", pred)
print("==" * 20)
print("0:不会患心脏病")
print("1:可能患心脏病")