模型参数加载中遇到的一些问题(BN层)

发布于:2023-02-01 ⋅ 阅读:(615) ⋅ 点赞:(0)

最近需要用到VGG16BN模型,便于需求,自拟了模型,没有使用官方源码,这导致在加载参数过程中遇到一些问题:

1.自拟的模型中BN层出现了一个权重文件中没有的参数 :track_running_stats:

nn.BatchNorm2d(channels, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)

这个参数能够:

①训练时统计forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1;

②如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).

【如下所示】

class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,  
        ##  如果不为None,会执行track_running_stats+1;如果是None,则执行通过t_r_s计算momentum
        affine=True,
        track_running_stats=True,  
        device=None,
        dtype=None
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:         
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked = self.num_batches_tracked + 1  
                if self.momentum is None:  
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  
                    exponential_average_factor = self.momentum

2.BN层的参数名存在误导

(自拟的卷积层没有添加bias,导致模型参数出现[weight,weight,bias,...]的结构,开始以为是卷积核设置的问题,后续发现被BN层误导了)

BN层有4组需要更新的参数,分别是running_mean,running_var,weight,bias。

这里的weight,bias会产生误导,实际上weight就是参数gamma,bias就是参数beta

【公式见下】

 

gamma,beta会对规范化后的值分别进行 scale(缩放)和shift(平移)操作。

而running_mean,running_var是当前mini-batch下的均值和方差,每计算一次都会更新。

【补充:涉及到在deployment阶段,设置有BN层的模型需要使用model.eval()语句控制BN层中的running_mean,running_std不更新,因为此阶段图片的喂入是单张的。同理,dropout也会受此机制影响,但dropout会背直接取消】

3.OrderedDict(有序字典,pytorch对模型参数的存储方式)的相关问题:

这个无非就是字典的使用,方法很多。这里给出其中必然会涉及的两个基本思路

①改键,搜键等:将键改为列表。

key_of_paramsdicA = []
for k,v in paramsdicA.items():
    key_of_paramsdicA.append(k)

②改值,迁移值等:内置循环。

dic_c = {(k, j) for k,v in dic_a.items() if k in dic_b}
# # 一般会和update配合使用
dic_b.update(dict_c)

注意:X.items需要用括号,否则报错TypeError: 'builtin_function_or_method' object is not iterable