【mmengine】优化器封装(OptimWrapper)(进阶)在执行器(Runner)中配置优化器封装(OptimWrapper)

发布于:2024-10-10 ⋅ 阅读:(15) ⋅ 点赞:(0)

一、 简单配置

  • 以配置一个 SGD 优化器封装为例:
    优化器封装需要接受 optimizer 参数,因此我们首先需要为优化器封装配置 optimizer。MMEngine 会自动将 PyTorch 中的所有优化器都添加进 OPTIMIZERS 注册表中

  • 以配置一个 SGD 优化器封装为例:

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)
  • 这样我们就配置好了一个优化器类型为 SGD 的优化器封装,学习率、动量等参数如配置所示。考虑到 OptimWrapper 为标准的单精度训练,因此我们也可以不配置 type 字段:
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(optimizer=optimizer)
  • 要想开启混合精度训练和梯度累加,需要将 type 切换成 AmpOptimWrapper,并指定 accumulative_counts 参数
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(type='AmpOptimWrapper', optimizer=optimizer, accumulative_counts=2)

二、 模型中的不同参数设置不同的超参数

PyTorch 的优化器支持对模型中的不同参数设置不同的超参数,例如对一个分类模型的骨干(backbone)和分类头(head)设置不同的学习率:

from torch.optim import SGD
import torch.nn as nn

model = nn.ModuleDict(dict(backbone=nn.Linear(1, 1), head=nn.Linear(1, 1)))

# backbone部分: lr=0.01, momentum=0.9
# head部分: lr=1e-3, momentum=0.8
optimizer = SGD(
    [
    {'params': model.backbone.parameters()},
    {'params': model.head.parameters(), 'lr': 1e-3,'momentum': 0.8}
    ],
    lr=0.01, momentum=0.9)

backbone部分: lr=0.01, momentum=0.9
head部分: lr=1e-3, momentum=0.8

三、 不同类型的参数设置不同的超参系数

例如,我们可以在 paramwise_cfg 中设置 norm_decay_mult=0,从而将正则化层(normalization layer)的权重(weight)和偏置(bias)的权值衰减系数(weight decay)设置为 0
具体示例如下,我们将 ToyModel 中所有正则化层(head.bn)的权重衰减系数设置为 0:

from mmengine.optim import build_optim_wrapper
from collections import OrderedDict

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.ModuleDict(
            dict(layer0=nn.Linear(1, 1), layer1=nn.Linear(1, 1)))
        self.head = nn.Sequential(
            OrderedDict(
                linear=nn.Linear(1, 1),
                bn=nn.BatchNorm1d(1)))

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.01, weight_decay=0.0001),
    paramwise_cfg=dict(norm_decay_mult=0))
optimizer = build_optim_wrapper(ToyModel(), optim_wrapper)