pytorch版本densenet代码讲解

发布于:2025-07-05 ⋅ 阅读:(24) ⋅ 点赞:(0)

DenseNet 模型代码详解

下面是 DenseNet 模型代码的逐部分详细解析:

1. 导入模块

import re
from collections import OrderedDict
from functools import partial
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
  • re: 正则表达式模块,用于处理权重名称的转换
  • OrderedDict: 有序字典,用于按顺序构建网络层
  • partial: 创建部分函数,用于预设图像转换参数
  • torch.nn: PyTorch 的神经网络模块
  • torch.utils.checkpoint: 内存优化技术,减少训练时的内存占用
  • ImageClassification: 图像分类的预处理转换
  • register_model: 注册模型的装饰器
  • Weights/WeightsEnum: 预训练权重相关类
  • _IMAGENET_CATEGORIES: ImageNet 数据集类别标签
  • 模型工具函数: 覆盖参数、处理旧版接口等

2. DenseNet 基础层 (_DenseLayer)

class _DenseLayer(nn.Module):
    def __init__(
        self, num_input_features: int, growth_rate: int, bn_size: int, 
        drop_rate: float, memory_efficient: bool = False
    ) -> None:
        super().__init__()
        # 第一个卷积块 (1x1 卷积)
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, 
                              kernel_size=1, stride=1, bias=False)
        
        # 第二个卷积块 (3x3 卷积)
        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, 
                              kernel_size=3, stride=1, padding=1, bias=False)
        
        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient
  • Bottleneck 结构: 由两个卷积层组成,减少计算量
  • 1x1 卷积: 降维,输出通道数为 bn_size * growth_rate
  • 3x3 卷积: 主卷积层,输出通道数为 growth_rate
  • memory_efficient: 是否使用梯度检查点节省内存

前向传播逻辑

    def bn_function(self, inputs: list[Tensor]) -> Tensor:
        # 拼接所有输入特征
        concated_features = torch.cat(inputs, 1)
        # 通过第一个卷积块
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))
        return bottleneck_output

    def forward(self, input: Tensor) -> Tensor:
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input
        
        # 内存高效模式处理
        if self.memory_efficient and self.any_requires_grad(prev_features):
            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)
        
        # 通过第二个卷积块
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        # 应用Dropout
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features
  • 特征拼接: 将前面所有层的输出拼接在一起
  • 梯度检查点: 在内存高效模式下,使用检查点减少内存占用
  • Dropout: 随机丢弃部分神经元,防止过拟合

3. Dense 块 (_DenseBlock)

class _DenseBlock(nn.ModuleDict):
    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        # 创建多个密集层
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module("denselayer%d" % (i + 1), layer)
  • 模块字典: 存储多个密集层
  • 输入特征计算: 每增加一层,输入特征增加 growth_rate 个通道

前向传播

    def forward(self, init_features: Tensor) -> Tensor:
        features = [init_features]
        # 逐层处理并收集输出
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        # 拼接所有层的输出
        return torch.cat(features, 1)
  • 特征累积: 每一层的输出都添加到特征列表中
  • 特征拼接: 将所有层的输出沿通道维度拼接

4. 过渡层 (_Transition)

class _Transition(nn.Sequential):
    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__()
        # 压缩特征维度
        self.norm = nn.BatchNorm2d(num_input_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_input_features, num_output_features, 
                             kernel_size=1, stride=1, bias=False)
        # 空间下采样
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
  • 特征压缩: 1x1 卷积减少通道数(通常减半)
  • 空间降维: 平均池化减小特征图尺寸

5. DenseNet 主模型

class DenseNet(nn.Module):
    def __init__(
        self,
        growth_rate: int = 32,
        block_config: tuple[int, int, int, int] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 1000,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)  # 记录API使用情况
        
        # 初始卷积层
        self.features = nn.Sequential(
            OrderedDict([
                ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                ("norm0", nn.BatchNorm2d(num_init_features)),
                ("relu0", nn.ReLU(inplace=True)),
                ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
            ])
        )
        
        # 构建多个Dense块和过渡层
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            # 添加Dense块
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features += num_layers * growth_rate
            
            # 添加过渡层(最后一个块除外)
            if i != len(block_config) - 1:
                trans = _Transition(num_features, num_features // 2)
                self.features.add_module("transition%d" % (i + 1), trans)
                num_features = num_features // 2
        
        # 最终批归一化
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))
        
        # 分类器
        self.classifier = nn.Linear(num_features, num_classes)
        
        # 参数初始化
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
  • 初始卷积层: 快速下采样输入图像
  • 块配置: 控制每个Dense块中的层数
  • 通道管理: 通过过渡层压缩通道数
  • Kaiming初始化: 卷积层的权重初始化
  • 批归一化初始化: 权重设为1,偏置设为0

前向传播

    def forward(self, x: Tensor) -> Tensor:
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))  # 全局平均池化
        out = torch.flatten(out, 1)  # 展平特征
        out = self.classifier(out)  # 分类
        return out
  • 特征提取: 通过多个Dense块和过渡层
  • 全局平均池化: 将特征图转换为特征向量
  • 全连接层: 输出分类结果

6. 权重加载函数

def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
    # 匹配旧版权重名称模式
    pattern = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )
    
    state_dict = weights.get_state_dict(progress=progress, check_hash=True)
    # 转换权重名称
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    # 加载权重
    model.load_state_dict(state_dict)
  • 权重名称转换: 适配旧版权重命名方式
  • 哈希校验: 确保下载的权重文件完整无误

7. 模型工厂函数

def _densenet(
    growth_rate: int,
    block_config: tuple[int, int, int, int],
    num_init_features: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> DenseNet:
    # 根据权重调整输出类别数
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    
    # 创建模型
    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
    
    # 加载预训练权重
    if weights is not None:
        _load_state_dict(model=model, weights=weights, progress=progress)
    
    return model
  • 参数覆盖: 根据预训练权重调整输出类别数
  • 灵活配置: 支持不同DenseNet变体

8. 预训练权重定义

_COMMON_META = {
    "min_size": (29, 29),  # 最小输入尺寸
    "categories": _IMAGENET_CATEGORIES,  # ImageNet类别
    "recipe": "https://github.com/pytorch/vision/pull/116",  # 训练方法
}

class DenseNet121_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
        transforms=partial(ImageClassification, crop_size=224),  # 图像预处理
        meta={
            **_COMMON_META,
            "num_params": 7978856,  # 参数量
            "_metrics": {  # 性能指标
                "ImageNet-1K": {
                    "acc@1": 74.434,  # top-1准确率
                    "acc@5": 91.972,  # top-5准确率
                }
            },
            "_ops": 2.834,  # 计算量 (GFLOPs)
            "_file_size": 30.845,  # 文件大小 (MB)
        },
    )
    DEFAULT = IMAGENET1K_V1  # 默认权重
  • 权重元数据: 包含模型性能和资源信息
  • 预处理定义: 指定图像分类任务的预处理流程
  • 性能指标: 提供在ImageNet上的评估结果

9. 模型变体实现

@register_model()  # 注册模型
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, 
               progress: bool = True, **kwargs: Any) -> DenseNet:
    weights = DenseNet121_Weights.verify(weights)  # 验证权重
    return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
  • DenseNet121: 增长率32,块配置[6,12,24,16],初始特征64
  • DenseNet169: 增长率32,块配置[6,12,32,32],初始特征64
  • DenseNet201: 增长率32,块配置[6,12,48,32],初始特征64
  • DenseNet161: 增长率48,块配置[6,12,36,24],初始特征96

DenseNet 关键特点

  1. 密集连接: 每一层都接收前面所有层的特征图作为输入
  2. 特征重用: 通过拼接实现多层次特征融合
  3. 瓶颈设计: 1×1卷积减少计算量
  4. 过渡层: 压缩特征维度和空间尺寸
  5. 高效内存: 可选的内存优化模式

DenseNet通过密集连接促进了特征重用,减少了梯度消失问题,提高了参数效率,在各种视觉任务中表现出色。


网站公告

今日签到

点亮在社区的每一天
去签到