深度学习实践:使用 PyTorch 实现简单 CNN 分类医学影像(胸部 X 光)

发布于:2025-08-07 ⋅ 阅读:(29) ⋅ 点赞:(0)

在这里插入图片描述

🧑 博主简介:CSDN博客专家、CSDN平台优质创作者,高级开发工程师,数学专业,10年以上C/C++, C#,Java等多种编程语言开发经验,拥有高级工程师证书;擅长C/C++、C#等开发语言,熟悉Java常用开发技术,能熟练应用常用数据库SQL server,Oracle,mysql,postgresql等进行开发应用,熟悉DICOM医学影像及DICOM协议,业余时间自学JavaScript,Vue,qt,python等,具备多种混合语言开发能力。撰写博客分享知识,致力于帮助编程爱好者共同进步。欢迎关注、交流及合作,提供技术支持与解决方案。\n技术合作请加本人wx(注明来自csdn):xt20160813


使用 PyTorch 实现简单 CNN 分类医学影像(胸部 X 光)

本文基于 Kaggle 胸部 X 光图像数据集,实现一个简单 CNN 模型,用于区分正常和肺炎 X 光图像。我们将涵盖数据预处理、模型构建、训练、评估和结果可视化,主要内容包括: 数据集:包含5,216张训练图像(1,341正常,3,875肺炎),存在类别不平衡问题。 预处理: 图像灰度化并调整为224×224 应用随机翻转、旋转等数据增强 使用标准化处理 模型架构: 3个卷积层(带ReLU和最大池化) 2个全连接层 适合二分类任务 实现特点: 详细注释的代码 包含数据可视化(Chart.js图表) 考虑计算资源优化 适用性:适合初学者和进阶开发者。
在这里插入图片描述


一、任务概述

  • 数据集:Kaggle 胸部 X 光图像数据集,包含约 5,216 张训练图像(1,341 正常,3,875 肺炎)。
  • 任务:二分类,预测 X 光图像是否为肺炎(0: 正常,1: 肺炎)。
  • 模型:简单 CNN,包含 3 个卷积层(带 ReLU 和最大池化)+ 2 个全连接层。
  • 环境:PyTorch, 推荐 GPU 加速。
  • 挑战
    • 类不平衡:肺炎样本占主导。
    • 图像噪声:X 光图像质量差异。
    • 计算资源:需优化模型以适应有限硬件。

二、实现步骤

2.1 环境设置

安装必要的 Python 库:

pip install torch torchvision opencv-python pandas numpy matplotlib seaborn

2.2 数据预处理

CNN 直接处理原始图像,无需手动特征提取。我们使用 torchvision 的数据增强和标准化来提高模型鲁棒性。

import os
import cv2
import numpy as np
from glob import glob
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ChestXRayDataset(Dataset):
    """
    胸部 X 光图像数据集
    """
    def __init__(self, image_paths, labels, transform=None):
        """
        初始化数据集
        :param image_paths: 图像路径列表
        :param labels: 标签列表 (0: 正常, 1: 肺炎)
        :param transform: 数据增强变换
        """
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # 读取灰度图像
        img = cv2.imread(self.image_paths[idx], cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (224, 224))  # 调整为 224x224
        img = img[:, :, np.newaxis]  # 增加通道维度 [224, 224, 1]
        if self.transform:
            img = self.transform(img)
        label = self.labels[idx]
        return img, label

# 数据增强和标准化
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量
    transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转
    transforms.RandomRotation(10),  # 随机旋转 ±10 度
    transforms.Normalize(mean=[0.5], std=[0.5])  # 灰度图像标准化
])

# 加载数据
data_dir = 'chest_xray/train'  # 替换为实际 Kaggle 数据集路径
normal_paths = glob(os.path.join(data_dir, 'NORMAL', '*.jpeg'))
pneumonia_paths = glob(os.path.join(data_dir, 'PNEUMONIA', '*.jpeg'))
image_paths = normal_paths + pneumonia_paths
labels = [0] * len(normal_paths) + [1] * len(pneumonia_paths)

# 划分训练集和测试集
train_paths, test_paths, train_labels, test_labels = train_test_split(
    image_paths, labels, test_size=0.2, random_state=42, stratify=labels
)

# 创建数据集和数据加载器
train_dataset = ChestXRayDataset(train_paths, train_labels, transform=transform)
test_dataset = ChestXRayDataset(test_paths, test_labels, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# 数据集统计
print(f'训练集样本数: {
     
     len(train_dataset)}, 测试集样本数: {
     
     len(test_dataset)}')
print(f'正常样本: {
     
     sum(labels == 0)}, 肺炎样本: {
     
     sum(labels == 1)}')

数据预处理流程的文本描述

  • 输入:X 光图像(JPEG 格式,灰度)。
  • 处理
    1. 读取图像,转换为灰度。
    2. 调整尺寸至 224x224。
    3. 增加通道维度,形状为 [224, 224, 1]。
    4. 数据增强:随机翻转、旋转,标准化(均值 0.5,标准差 0.5)。
  • 输出:张量形式的图像和标签,送入 DataLoader 分批处理。
  • 箭头:从图像文件到张量,标注变换步骤(Resize → Transform → Tensor)。

数据集分布可视化
以下 图表展示训练集和测试集的类别分布。
在这里插入图片描述

{
  "type": "bar",
  "data": {
    "labels": ["正常", "肺炎"],
    "datasets": [
      {
        "label": "训练集",
        "data": [1072, 3100],  // 假设 80% 训练集比例
        "backgroundColor": "#1f77b4",
        "borderColor": "#1f77b4",
        "borderWidth": 1
      },
      {
        "label": "测试集",
        "data": [269, 775],  // 假设 20% 测试集比例
        "backgroundColor": "#ff7f0e",
        "borderColor": "#ff7f0e",
        "borderWidth": 1
      }
    ]
  },
  "options": {
    "scales": {
      "y": {
        "beginAtZero": true,
        "title": {
          "display": true,
          "text": "样本数量"
        }
      },
      "x": {
        "title": {
          "display": true,
          "text": "类别"
        }
      }
    },
    "plugins": {
      "title": {
        "display": true,
        "text": "训练集与测试集类别分布"
      }
    }
  }
}

2.3 定义简单 CNN 模型

设计一个简单 CNN,包含 3 个卷积层(带 ReLU 和最大池化)+ 2 个全连接层,适合二分类任务。

import torch.nn as nn

class SimpleCNN(nn.Module):
    """
    简单 CNN 模型,用于 X 光图像二分类
    """
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv_layers = nn.Sequential(
            # 卷积层 1: 输入 [1, 224, 224] -> 输出 [16, 112, 112]
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 卷积层 2: 输入 [16, 112, 112] -> 输出 [32, 56, 56]
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 卷积层 3: 输入 [32, 56, 56] -> 输出 [64, 28, 28]
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc_layers = nn.Sequential(
            nn.Flatten(),  # 展平 [64, 28, 28] -> [64*28*28]
            nn.Linear(64 * 28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.5),  # 防止过拟合
            nn.Linear(512, 1),
            nn.Sigmoid()  # 二分类输出
        )

网站公告

今日签到

点亮在社区的每一天
去签到