特征图可视化代码

发布于:2025-06-01 ⋅ 阅读:(23) ⋅ 点赞:(0)
  • 进行特征图可视化的时候,修改模型的forward函数来进行可视化十分麻烦,还需要想办法把特征图传出来,在模型层层调用的时候更加麻烦,要修改多个无关的嵌套,还容易引起bug。这里提供了一个简单的范式,只需要一个vis.py文件(可从train.py或者test.py修改而来),无需修改模型的定义文件,即可实现特征图的可视化。
  • 该做法的核心思想是两点,第一点是利用vis.py里面的全局变量来存储特征图以及网络层数等,第二点是直接在vis.py里面重写需要可视化特征图的module的forward函数,以用最小的改动将特征图传递出来。
  • 这段代码还提供了利用sns.heatmap可视化特征图的例子,整体代码如下:
# vis talor mod
import argparse
import os
import math
from functools import partial

import yaml
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import datasets
import models
import utils
from torchvision import transforms
from PIL import Image
import random
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
from models.models_meta import mGAttn
from einops import rearrange as rearrange


@torch.no_grad()
def vis_mod(mod_dict, path, name):
    for i in [3, 13]:#[3,5,19]:
        scale = mod_dict[f'layer_{i}_scale'] 
        offset = mod_dict[f'layer_{i}_offset']
        name_scale_i = name+f'scale_{i}_avg.png'
        vis_feature(scale, path, name_scale_i)
        vis_feature_each_channel(scale, path, name_scale_i)
        name_offset_i = name+f'offset_{i}_avg.png'
        vis_feature(offset, path, name_offset_i)   
        vis_feature_each_channel(offset, path, name_offset_i)   
    return 


def vis_feature(feature, path, name):
    feature = feature[0, ...]
    # lower_percentile = 0.1
    # upper_percentile = 0.9
    # for i in range(feature.shape[0]):
    #     feature_i = feature[i].view(-1)
    #     lower_bound = torch.quantile(feature_i, lower_percentile)
    #     upper_bound = torch.quantile(feature_i, upper_percentile)
    #     feature[i,...] = torch.clamp(feature[i,...], lower_bound, upper_bound)
    # plt.figure()
    plt.figure(figsize=(1.58, 1.58))
    ax = sns.heatmap(torch.mean(feature, dim=0).cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')
    ax.tick_params(axis='both', which='both', length=0)
    plt.tight_layout()
    plt.savefig(os.path.join(path, name))
    plt.close()

    for i in range(feature.size(0)//32):
        plt.figure(figsize=(1.58, 1.58))
        ax = sns.heatmap(torch.mean(feature[i*32:(i+1)*32, :, :], dim=0).cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')
        ax.tick_params(axis='both', which='both', length=0)
        plt.tight_layout()
        plt.savefig(os.path.join(path, name.replace('avg', f'avg_{i}')))
        plt.close()

    

def vis_feature_each_channel(feature, path, name):
    feature = feature[0, ...]
    # lower_percentile = 0.1
    # upper_percentile = 0.9
    # for i in range(feature.shape[0]):
    #     feature_i = feature[i].view(-1)
    #     lower_bound = torch.quantile(feature_i, lower_percentile)
    #     upper_bound = torch.quantile(feature_i, upper_percentile)
    #     feature[i,...] = torch.clamp(feature[i,...], lower_bound, upper_bound)
    for i in range(feature.shape[0]):
        # plt.figure()
        plt.figure(figsize=(1.58, 1.58))
        ax = sns.heatmap(feature[i].cpu().detach().numpy(), cbar=False, annot=False, xticklabels=[], yticklabels=[], cmap='rainbow')
        ax.tick_params(axis='both', which='both', length=0)
        plt.tight_layout()
        plt.savefig(os.path.join(path, name.replace('avg', f'channel_{i}')))
        plt.close()
    



global_feature_maps = {}
def modify_forward_for_mGAttn(module):

    if isinstance(module, mGAttn):
        # original_forward = module.forward

        def modified_forward(self, x):
            """
            x: b * c * h * w
            """
            # 这里省略了模型原有的一些forward过程

            curr_layer = global_feature_maps['curr_layer']
            if curr_layer in [3, 13]:#[1,5,19]:
                B, h, Ch, N = offset.shape
                global_feature_maps[f'layer_{curr_layer}'] = feature.view(B, h, He, We)
            global_feature_maps['curr_layer'] = curr_layer+1

            # 这里省略了模型原有的一些forward过程
            return out

        module.forward =  modified_forward.__get__(module)

    for child_module in module.children():
        modify_forward_for_mGAttn(child_module)




if __name__ == '__main__':
    # 这里省略了一些模型的定义过程
    
    # modify forward for mGAttn
    modify_forward_for_mGAttn(model)
    # 接着按自己的方式直接调用模型即可
    myeval(model)


网站公告

今日签到

点亮在社区的每一天
去签到