【mmsegmentation】Backbone模块(进阶)自定义自己的BACKBONE

发布于:2024-10-09 ⋅ 阅读:(50) ⋅ 点赞:(0)

1、定义自己的backbone

driving\models\backbones\efficientnetlite.py


import math
import torch
import torch.nn as nn
import torch.functional as F
from mmengine.model import BaseModule
from mmseg.models import BACKBONES, build_backbone

# Per-variant scaling table. Each value is a list laid out as
# [width_coefficient, depth_coefficient, image_size, dropout_rate].
efficientnet_lite_params = dict(
    efficientnet_lite0=[1.0, 1.0, 224, 0.2],
    efficientnet_lite1=[1.0, 1.1, 240, 0.2],
    efficientnet_lite2=[1.1, 1.2, 260, 0.3],
    efficientnet_lite3=[1.2, 1.4, 280, 0.3],
    # An alternative lite4 entry with image_size 224 was previously kept here
    # as a disabled line; the active entry uses 300.
    efficientnet_lite4=[1.4, 1.8, 300, 0.3],
)

def round_filters(filters, multiplier, divisor=8, min_width=None):
    """Scale a channel count by the width multiplier and snap it to ``divisor``.

    :param filters: base number of channels
    :param multiplier: width multiplier; a falsy value returns ``filters`` unchanged
    :param divisor: the result is a multiple of this value
    :param min_width: lower bound for the result; falls back to ``divisor`` when falsy
    :return: the adjusted channel count as an int
    """
    if not multiplier:
        return filters

    scaled = filters * multiplier
    lower_bound = min_width if min_width else divisor

    # Round to the nearest multiple of ``divisor`` and clamp from below.
    snapped = int(scaled + divisor / 2) // divisor * divisor
    if snapped < lower_bound:
        snapped = lower_bound

    # Rounding must not shrink the scaled width by more than 10%.
    if snapped < 0.9 * scaled:
        snapped += divisor
    return int(snapped)

def round_repeats(repeats, multiplier):
    """Scale a block repeat count by the depth multiplier, rounding up.

    A falsy ``multiplier`` returns ``repeats`` unchanged.
    """
    return int(math.ceil(multiplier * repeats)) if multiplier else repeats

def drop_connect(x, drop_connect_rate, training):
    """Randomly zero whole samples of ``x`` and rescale the survivors.

    :param x: input tensor; the mask is built with shape [batch, 1, 1, 1],
        so one keep/drop decision is made per sample along dim 0
    :param drop_connect_rate: probability of dropping a sample
    :param training: when falsy, ``x`` is returned untouched
    :return: ``x`` itself outside training, otherwise the masked, rescaled tensor
    """
    if training:
        survive_prob = 1.0 - drop_connect_rate
        noise = torch.rand([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device)
        # floor(survive_prob + U[0, 1)) is 1 with probability survive_prob, else 0.
        keep_mask = torch.floor(survive_prob + noise)
        # Dividing by survive_prob compensates for the dropped samples.
        return (x / survive_prob) * keep_mask
    return x

class MBConvBlock(nn.Module):
    """Inverted-bottleneck block: 1x1 expand -> depthwise conv -> optional SE -> 1x1 project.

    Args:
        inp: number of input channels.
        final_oup: number of output channels after the projection conv.
        k: depthwise kernel size; padding is ``(k - 1) // 2``.
        s: stride of the depthwise conv.
        expand_ratio: channel expansion factor; when it equals 1 the expansion
            conv and its batch norm are not created.
        se_ratio: squeeze ratio; the SE bottleneck has
            ``max(1, int(inp * se_ratio))`` channels. Only read when ``has_se``.
        has_se: whether to build and apply the squeeze-and-excitation branch.

    The residual connection is applied only when the stride is 1 and the input
    and output channel counts match.
    """

    def __init__(self, inp, final_oup, k, s, expand_ratio, se_ratio, has_se=False):
        super(MBConvBlock, self).__init__()

        self._momentum = 0.01
        self._epsilon = 1e-3
        self.input_filters = inp
        self.output_filters = final_oup
        self.stride = s
        self.expand_ratio = expand_ratio
        self.has_se = has_se
        self.id_skip = True  # skip connection and drop connect

        # Expansion phase
        oup = inp * expand_ratio  # number of channels inside the bottleneck
        if expand_ratio != 1:
            self._expand_conv = nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._momentum, eps=self._epsilon)

        # Depthwise convolution phase
        self._depthwise_conv = nn.Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, padding=(k - 1) // 2, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._momentum, eps=self._epsilon)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            num_squeezed_channels = max(1, int(inp * se_ratio))
            self._se_reduce = nn.Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = nn.Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Output phase
        self._project_conv = nn.Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._momentum, eps=self._epsilon)
        self._relu = nn.ReLU6(inplace=True)

    def forward(self, x, drop_connect_rate=None):
        """
        :param x: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """

        # Expansion and Depthwise Convolution
        identity = x
        if self.expand_ratio != 1:
            x = self._relu(self._bn0(self._expand_conv(x)))
        x = self._relu(self._bn1(self._depthwise_conv(x)))

        # Squeeze and Excitation
        if self.has_se:
            # Pooling lives in torch.nn.functional; go through ``nn`` so this
            # does not depend on what a module-level ``F`` alias is bound to
            # (torch.functional has no adaptive_avg_pool2d).
            x_squeezed = nn.functional.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_expand(self._relu(self._se_reduce(x_squeezed)))
            x = torch.sigmoid(x_squeezed) * x

        # Linear projection: no activation after the final batch norm.
        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect
        if self.id_skip and self.stride == 1 and self.input_filters == self.output_filters:
            if drop_connect_rate:
                x = drop_connect(x, drop_connect_rate, training=self.training)
            x += identity  # skip connection
        return x


# class EfficientNetLite(nn.Module):
@BACKBONES.register_module()
class EfficientNetLite(BaseModule):
    """EfficientNet-Lite feature extractor registered in the mmseg model registry.

    The network is a stride-2 stem followed by seven stages of ``MBConvBlock``
    (built without squeeze-and-excitation). ``forward`` returns the output of
    the last stage as a single tensor.

    Args:
        model_name (str): key into ``efficientnet_lite_params`` selecting the
            width/depth multipliers. Defaults to ``'efficientnet_lite4'``,
            the variant that used to be hard-coded.

    Raises:
        ValueError: if ``model_name`` is not a key of ``efficientnet_lite_params``.
    """

    def __init__(
        self,
        model_name="efficientnet_lite4",
    ):
        super(EfficientNetLite, self).__init__()
        print(" EfficientNetLite __init__")

        if model_name not in efficientnet_lite_params:
            raise ValueError(
                f"unknown model_name {model_name!r}; "
                f"expected one of {sorted(efficientnet_lite_params)}")
        # NOTE(review): the fourth table entry is labelled ``dropout_rate`` in
        # the table comment but is used here as the drop-connect rate — confirm
        # this is intended.
        width_multiplier, depth_multiplier, _, drop_connect_rate = efficientnet_lite_params[model_name]

        # Batch norm parameters
        momentum = 0.01
        epsilon = 1e-3
        self.drop_connect_rate = drop_connect_rate

        mb_block_settings = [
            # repeat|kernel_size|stride|expand|input|output|se_ratio
            [1, 3, 1, 1, 32,  16,  0.25],
            [2, 3, 2, 6, 16,  24,  0.25],
            [2, 5, 2, 6, 24,  40,  0.25],
            [3, 3, 2, 6, 40,  80,  0.25],
            [3, 5, 1, 6, 80,  112, 0.25],
            [4, 5, 2, 6, 112, 192, 0.25],
            [1, 3, 1, 6, 192, 320, 0.25],
        ]
        last_stage = len(mb_block_settings) - 1

        # Stem
        out_channels = 32
        self.stem = nn.Sequential(
            nn.Conv2d(3, out_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=out_channels, momentum=momentum, eps=epsilon),
            nn.ReLU6(inplace=True),
        )

        # Build blocks
        self.blocks = nn.ModuleList([])
        for i, stage_setting in enumerate(mb_block_settings):
            stage = nn.ModuleList([])
            num_repeat, kernel_size, stride, expand_ratio, input_filters, output_filters, se_ratio = stage_setting
            # Update block input and output filters based on width multiplier.
            # The first stage's input is the fixed-width stem, so it is not scaled.
            input_filters = input_filters if i == 0 else round_filters(input_filters, width_multiplier)
            output_filters = round_filters(output_filters, width_multiplier)
            # The first and last stages keep their repeat count regardless of depth.
            if i != 0 and i != last_stage:
                num_repeat = round_repeats(num_repeat, depth_multiplier)

            # The first block needs to take care of stride and filter size increase.
            stage.append(MBConvBlock(input_filters, output_filters, kernel_size, stride, expand_ratio, se_ratio, has_se=False))
            # Remaining blocks of the stage keep resolution and channel count.
            for _ in range(num_repeat - 1):
                stage.append(MBConvBlock(output_filters, output_filters, kernel_size, 1, expand_ratio, se_ratio, has_se=False))

            self.blocks.append(stage)

        # Total number of MBConv blocks across all stages; forward() uses it
        # to scale the drop-connect rate with depth.
        self._total_blocks = sum(len(stage) for stage in self.blocks)

        self._initialize_weights()

    def forward(self, x):
        """Run the stem and every block in order and return the final feature map.

        NOTE(review): this returns a single tensor, not a sequence of
        multi-scale feature maps — confirm that matches what the decode head
        consuming this backbone expects.
        """
        x = self.stem(x)
        idx = 0
        for stage in self.blocks:
            for block in stage:
                drop_connect_rate = self.drop_connect_rate
                if drop_connect_rate:
                    # Scale linearly with the block index. ``idx`` counts
                    # blocks, so normalise by the number of blocks; dividing
                    # by the number of stages pushed the rate above 1.
                    drop_connect_rate *= float(idx) / self._total_blocks
                x = block(x, drop_connect_rate)
                idx += 1
        print(" EfficientNetLite forward:", x.shape)
        return x

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / (kernel_h * kernel_w * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 1.0/float(n))
                m.bias.data.zero_()

    def load_pretrain(self, path):
        """Load a state dict from ``path``; keys must match exactly (strict)."""
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a
        # machine without one; load_state_dict copies into the existing params.
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict, strict=True)

看下BACKBONES注册表(@BACKBONES.register_module())
在这里插入图片描述
可以看到EfficientNetLite被注册到HEADS

  • 其实,这里的 BACKBONES、NECKS、HEADS、LOSSES、SEGMENTORS 都指向同一个注册表 MODELS(它们只是别名),所以无论用哪个名字注册,类都会出现在同一张表里:
# Every name below is bound to the same MODELS object, so a class registered
# through any one of them is visible through all of the others.
from mmseg.registry import MODELS

BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
SEGMENTORS = MODELS
  • 看下这里的BaseModule,mmengine\model\base_module.py
    在这里插入图片描述

2、调用EfficientNetLite

if __name__ == "__main__":
    print("call EfficientNetLite backbone:")
    # 1. Config dict: ``type`` is the name the class was registered under.
    backbone_cfg = dict(
        type='EfficientNetLite',
        )
    # 2. Build the backbone from the registry.
    backbone_obj = build_backbone(backbone_cfg)
    backbone_obj.eval()

    # 3. Forward pass on a dummy batch at the variant's image size.
    name = "efficientnet_lite4"
    wh = efficientnet_lite_params[name][2]
    # torch.randn gives defined values; torch.Tensor(2, 3, wh, wh) only
    # allocates memory and may contain NaN/Inf.
    dummy_input = torch.randn(2, 3, wh, wh)
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        output = backbone_obj(dummy_input)

在这里插入图片描述