CNN+Transformer+SE注意力机制多分类模型 + SHAP特征重要性分析,pytorch框架

发布于:2025-03-30 ⋅ 阅读:(24) ⋅ 点赞:(0)

效果一览

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

代码功能

CNN提取一维序列的局部特征,如光谱峰值、表格数据趋势等。Transformer捕捉一维序列的全局依赖关系,解决长序列建模难题! 弥补CNN在长距离依赖建模上的不足,提升模型的全局特征提取能力。SE注意力机制动态调整特征通道权重,聚焦关键信息,提升分类精度!
支持多类别分类任务,适用于光谱分类、表格数据分类、时间序列分类等场景。
可自定义类别数量
输出训练损失和准确率,并评估训练集和测试集的准确率,精确率,召回率,f1分数,绘制roc曲线,混淆矩阵
结合SHAP(Shapley Additive exPlanations),直观展示每个特征对分类结果的影响!
包括蜂巢图,重要性图,单特征力图,决策图,热图,瀑布图等。

CNN+Transformer+SE注意力机制多分类模型 + SHAP特征重要性分析

模型架构与核心组件

1. CNN(卷积神经网络)

功能

  • 局部特征提取:通过一维卷积核滑动窗口(如核大小=3/5/7),捕获序列中的局部模式(如光谱峰值、数据趋势)。
  • 特征增强:使用多层卷积堆叠(Conv1D+ReLU+MaxPool),逐步抽象高阶特征,输出维度为 (batch_size, channels, seq_len)

实现代码片段

class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding='same')
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(2)
    def forward(self, x):
        return self.pool(self.relu(self.conv(x)))

2. Transformer

功能

  • 全局依赖建模:利用自注意力机制(Multi-Head Attention)捕捉长序列中的上下文关系。
  • 位置编码:添加可学习的位置编码(Positional Encoding),解决序列顺序问题。
  • 特征融合:输出全局特征矩阵 (batch_size, seq_len, d_model)

实现代码片段

class TransformerBlock(nn.Module):
    def __init__(self, d_model, nhead, num_layers=2):
        super().__init__()
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
    def forward(self, x):
        x = self.pos_encoder(x.permute(0,2,1))  # 调整维度为 (seq_len, batch, d_model)
        return self.transformer(x).permute(1,0,2)

3. SE(Squeeze-and-Excitation)注意力机制

功能

  • 通道权重动态调整:通过全局平均池化(Squeeze)和全连接层(Excitation),生成通道权重向量。
  • 特征增强:对CNN或Transformer输出的特征图进行通道级加权,公式:
    ( \text{SE}(x) = x \cdot \sigma(W_2 \cdot \text{ReLU}(W_1 \cdot \text{GAP}(x))) )

实现代码片段

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels//reduction),
            nn.ReLU(),
            nn.Linear(channels//reduction, channels),
            nn.Sigmoid()
        )
    def forward(self, x):
        weights = self.fc(self.gap(x).squeeze(-1))
        return x * weights.unsqueeze(-1)

4. 多分类任务支持

功能

  • 输出层:全连接层 + Softmax,支持自定义类别数(num_classes)。
  • 评估指标:准确率、精确率、召回率、F1分数、ROC-AUC、混淆矩阵。

实现代码片段

# 模型输出层
self.fc = nn.Linear(d_model, num_classes)

# 评估函数
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
def evaluate(y_true, y_pred):
    print(classification_report(y_true, y_pred))
    print("Confusion Matrix:\n", confusion_matrix(y_true, y_pred))
    print("ROC-AUC:", roc_auc_score(y_true, y_pred, multi_class='ovr'))

5. SHAP特征重要性分析

功能

  • 模型可解释性:基于博弈论的SHAP值,量化特征对分类结果的贡献。
  • 可视化工具:支持蜂巢图、决策图、热力图等,适配序列输入。

实现代码片段

import shap

# 初始化解释器
explainer = shap.DeepExplainer(model, background_data)

# 计算SHAP值
shap_values = explainer.shap_values(test_sample)

# 可视化(示例:特征重要性图)
shap.summary_plot(shap_values, test_sample, plot_type='bar')

模型整合与训练流程

完整模型架构

class CNNTransformerSE(nn.Module):
    def __init__(self, input_dim, num_classes, d_model=64, nhead=4):
        super().__init__()
        self.cnn = CNNBlock(input_dim, 32)
        self.se1 = SEBlock(32)
        self.transformer = TransformerBlock(d_model, nhead)
        self.se2 = SEBlock(d_model)
        self.fc = nn.Linear(d_model, num_classes)
    
    def forward(self, x):
        x = self.cnn(x)          # CNN提取局部特征
        x = self.se1(x)          # SE增强局部特征
        x = self.transformer(x)  # Transformer建模全局依赖
        x = self.se2(x.mean(dim=1))  # SE增强全局特征 + 池化
        return self.fc(x)

训练与评估

# 数据加载(示例:光谱数据集)
from torch.utils.data import DataLoader
train_loader = DataLoader(SpectrumDataset(), batch_size=32, shuffle=True)

# 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(100):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
    # 每轮评估
    model.eval()
    evaluate(test_labels, model(test_data).argmax(axis=1))

应用场景与优势

  • 适用领域:光谱分类(如化学物质识别)、表格数据分类(如医疗诊断)、时间序列预测(如股票趋势分析)。
  • 优势
    • 局部-全局特征互补:CNN捕捉细节,Transformer建模长依赖,SE优化特征权重。
    • 高可解释性:SHAP分析直观展示关键特征,适用于需要决策透明度的场景(如医疗、金融)。
  • 案例数据集:内置SpectrumDataset示例,支持自定义CSV或NumPy数据输入。

环境依赖

  • 框架:PyTorch ≥1.8.0 + CUDA(可选)
  • 依赖库scikit-learn(评估指标)、shap(可解释性分析)、matplotlib(可视化)