day47 注意力热图可视化

发布于:2025-07-02 ⋅ 阅读:(22) ⋅ 点赞:(0)

昨天介绍了特征图的可视化,今天介绍热力图的可视化。

可视化部分同理,在训练完成后通过钩子函数取出权重或梯度,即可进行特征图的可视化,Grad-CAM课可视化、注意力热图可视化

# 可视化空间注意力热力图(显示模型关注的图像区域)
def visualize_attention_map(model, test_loader, device, class_names, num_samples=3):
    """可视化模型的注意力热力图,展示模型关注的图像区域"""
    model.eval()  # 设置为评估模式
    
    with torch.no_grad():
        for i, (images, labels) in enumerate(test_loader):
            if i >= num_samples:  # 只可视化前几个样本
                break
                
            images, labels = images.to(device), labels.to(device)
            
            # 创建一个钩子,捕获中间特征图
            activation_maps = []
            
            def hook(module, input, output):
                activation_maps.append(output.cpu())
            
            # 为最后一个卷积层注册钩子(获取特征图)
            hook_handle = model.conv3.register_forward_hook(hook)
            
            # 前向传播,触发钩子
            outputs = model(images)
            
            # 移除钩子
            hook_handle.remove()
            
            # 获取预测结果
            _, predicted = torch.max(outputs, 1)
            
            # 获取原始图像
            img = images[0].cpu().permute(1, 2, 0).numpy()
            # 反标准化处理
            img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3)
            img = np.clip(img, 0, 1)
            
            # 获取激活图(最后一个卷积层的输出)
            feature_map = activation_maps[0][0].cpu()  # 取第一个样本
            
            # 计算通道注意力权重(使用SE模块的全局平均池化)
            channel_weights = torch.mean(feature_map, dim=(1, 2))  # [C]
            
            # 按权重对通道排序
            sorted_indices = torch.argsort(channel_weights, descending=True)
            
            # 创建子图
            fig, axes = plt.subplots(1, 4, figsize=(16, 4))
            
            # 显示原始图像
            axes[0].imshow(img)
            axes[0].set_title(f'原始图像\n真实: {class_names[labels[0]]}\n预测: {class_names[predicted[0]]}')
            axes[0].axis('off')
            
            # 显示前3个最活跃通道的热力图
            for j in range(3):
                channel_idx = sorted_indices[j] 
                # 获取对应通道的特征图
                channel_map = feature_map[channel_idx].numpy()
                # 归一化到[0,1]
                channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min() + 1e-8)
                
                # 调整热力图大小以匹配原始图像
                from scipy.ndimage import zoom
                heatmap = zoom(channel_map, (32/feature_map.shape[1], 32/feature_map.shape[2]))
                
                # 显示热力图
                axes[j+1].imshow(img)
                axes[j+1].imshow(heatmap, alpha=0.5, cmap='jet')
                axes[j+1].set_title(f'注意力热力图 - 通道 {channel_idx}')
                axes[j+1].axis('off')
            
            plt.tight_layout()
            plt.show()

# 调用可视化函数
visualize_attention_map(model, test_loader, device, class_names, num_samples=3)

这个注意力热图是通过钩子机制:register_forward_hook捕获最后一个卷积层(conv3)的输出特征图。

        1. 通道权重计算:对特征图的每个通道进行全局平均池化,得到通道重要性权重
        2. 热力图生成:将高权重通道的特征图缩放至原始图像尺寸,与原图叠加显示。

热力图(红色表示高关注,蓝色表示低关注)半透明覆盖在原图上,主要从以下方面理解:

        - 高关注区域(红色):模型任务对分类最重要的区域。例如
                - 在识别“狗”时,热力图可能聚焦狗的面部、身体轮廓或特征性纹理。
                - 若热力图错误聚焦背景(如红色区域在无关物体上),可能表示模型过拟合或训练不足。

多通道对比:

        - 不同通道关注不同特征,例如:
                - 通道1可能关注整体轮廓,通道2关注纹理细节,通道3关注颜色分布。
                - 结合多个通道的热力图,可全面理解模型的决策逻辑。

可以帮助解释

        - 检查模型是否关注正确区域(如识别狗时,是否聚焦狗而非背景)。
        - 发现数据标注问题(如标签错误、图像噪声)。
        - 向非技术人员解释模型决策数据(如“模型认为这是狗,因为关注了眼睛和嘴巴”)。

@浙大疏锦行 


网站公告

今日签到

点亮在社区的每一天
去签到