《datawhale2411组队学习 模型压缩技术7:NNI剪枝》

发布于:2024-11-29 ⋅ 阅读:(14) ⋅ 点赞:(0)

一、NNI简介

  NNI(Neural Network Intelligence)是一个开源的自动机器学习(AutoML)工具,由微软亚洲研究院推出。它可以帮助用户自动化地进行超参数调优、神经网络架构搜索、模型压缩和特征工程等任务。NNI 支持多种深度学习框架,如PyTorch、TensorFlow等,并且可以在多种训练平台上运行,包括本地机器、远程服务器、Kubernetes等。NNI主要有以下功能:

NNI剪枝方法 描述 参考论文
Level Pruner 基于权重元素的绝对值,对每个权重元素按指定比例进行剪枝。
L1 Norm Pruner 使用最小 L1 权重范数修剪输出通道 《Pruning Filters for Efficient Convnets》
L2 Norm Pruner 使用最小 L2 权重范数修剪输出通道
FPGM Pruner 通过几何中值进行滤波器剪枝的深度卷积神经网络加速。 《Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration》
Slim Pruner 通过剪除 BN 层中的缩放因子来修剪输出通道 《Learning Efficient Convolutional Networks through Network Slimming》
Taylor FO Weight Pruner 基于权重的一阶泰勒展开计算重要性,对过滤器进行剪枝 《Importance Estimation for Neural Network Pruning》
Linear Pruner 在每轮剪枝中,稀疏率线性增加,并使用基础剪枝方法对模型进行剪枝。
AGP Pruner 自动渐进剪枝 《To prune, or not to prune: exploring the efficacy of pruning for model compression》
Movement Pruner 运动剪枝,通过微调实现自适应稀疏性。 《Movement Pruning: Adaptive Sparsity by Fine-Tuning》
NNI量化方法 描述 参考论文
QAT Quantizer 用于高效整数算术推理的神经网络的量化和训练。 《Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference》
DoReFa Quantizer DoReFa-Net:训练具有低位宽梯度的低位宽卷积神经网络。 《DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients》
BNN Quantizer 二值化神经网络:训练权重和激活限制为 +1 或 -1 《Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1》
LSQ Quantizer 学习步长量化。 《Learned Step Size Quantization》
PTQ Quantizer 训练后量化。
  • 神经网络架构搜索
  • 知识蒸馏
    • DynamicLayerwiseDistiller:每个学生模型蒸馏目标(即学生模型中层的输出)将链接到该蒸馏器中的教师模型蒸馏目标列表。在蒸馏过程中,学生目标将计算与其链接的每个教师目标的蒸馏损失列表,然后选择损失列表中的最小损失作为当前学生目标蒸馏损失。最终蒸馏损失是每个学生目标蒸馏损失乘以 lambda 的总和。最终的训练损失是原始损失乘以 origin_loss_lambda 加上最终的蒸馏损失。
    • Adaptive1dLayerwiseDistiller:该蒸馏器将通过在学生蒸馏目标和教师蒸馏目标之间添加可训练的torch.nn.Linear来自适应地调整学生蒸馏目标和教师蒸馏目标之间的最后一个维度。 (如果学生和教师之间的最后一个维度已经对齐,则不会添加新的线性图层。)

  此外,NNI还支持超参数调优、特征工程和实验管理,更多内容请查看NNI文档。其安装方式也很简单:

pip install nni

二、 NNI剪枝快速入门

参考文档《Pruning Quickstart》

  模型剪枝是一种通过减少模型权重大小或中间状态大小来减少模型大小和计算量的技术。修剪 DNN 模型有以下三种常见做法:

  • 预训练模型 -> 修剪模型 -> 微调修剪后的模型
  • 在训练期间修剪模型(即修剪感知训练)-> 微调修剪后的模型
  • 修剪模型 -> 从头开始​​训练修剪后的模型

  NNI支持上述所有的方式,本节以第一种方法为例来展示NNI的用法。我们使用一个简单的模型TorchModel(类似LeNet,只是最后多了几个relu函数和MaxPool层),此模型在MNIST数据集上进行了预训练。

2.1 加载并训练模型

import torch
import torch.nn.functional as F
from torch.optim import SGD

from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device

model = TorchModel().to(device)
model
TorchModel(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
# 定义优化器和损失函数
optimizer = SGD(model.parameters(), 1e-2)
criterion = F.nll_loss

# 训练并评估模型
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)
Average test loss: 0.7821, Accuracy: 7228/10000 (72%)
Average test loss: 0.2444, Accuracy: 9262/10000 (93%)
Average test loss: 0.1760, Accuracy: 9493/10000 (95%)
# 1.查看原始模型参数量
print('Original model paramater number: ', sum([param.numel() for param in model.parameters()]))

# 2. 测试原模型的推理速度
import time
start = time.time()
model(torch.rand(128, 1, 28, 28).to(device))
print('Original Model - Elapsed Time : ', time.time() - start)
Original model paramater number:  44426
Original Model - Elapsed Time :  2.3036391735076904

2.2 模型剪枝

  通常,修剪器需要原始模型和config_list作为其输入,配置config_list的详细信息,请参阅压缩配置规范。下面我们使用L1NormPruner减掉所有全连接层(除了fc3)和卷积层50%的参数:

from nni.compression.pruning import L1NormPruner

# 1. 定义剪枝器和剪枝配置信息
config_list = [{
    'op_types': ['Linear', 'Conv2d'],
    'exclude_op_names': ['fc3'],
    'sparse_ratio': 0.5
}]
pruner = L1NormPruner(model, config_list)
model
TorchModel(
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (_nni_wrapper): ModuleWrapper(module=Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1)), module_name=conv1)
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (_nni_wrapper): ModuleWrapper(module=Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)), module_name=conv2)
  )
  (fc1): Linear(
    in_features=256, out_features=120, bias=True
    (_nni_wrapper): ModuleWrapper(module=Linear(in_features=256, out_features=120, bias=True), module_name=fc1)
  )
  (fc2): Linear(
    in_features=120, out_features=84, bias=True
    (_nni_wrapper): ModuleWrapper(module=Linear(in_features=120, out_features=84, bias=True), module_name=fc2)
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
# 2.压缩模型并生成剪枝掩码,模拟剪枝的效果(compress方法不会永久改变模型的结构)
# 此掩码定义了哪些权重应该被设置为零(即剪枝)
_, masks = pruner.compress()

# 3.打印掩码的稀疏度
for name, mask in masks.items():
    print(name, ' sparsity : ', '{:.2}'.format(mask['weight'].sum() / mask['weight'].numel()))
fc1  sparsity :  0.5
conv1  sparsity :  0.5
conv2  sparsity :  0.5
fc2  sparsity :  0.5

2.3 模型加速(剪枝永久化)

  剪枝器(如 L1NormPruner)在应用剪枝时,使用了权重掩模(mask)来模拟剪枝效果。这些掩模并不会真正改变模型的结构,而是通过将权重的某些部分“置零”来模拟稀疏性。这么做可以测试稀疏性对模型性能(如准确率)的影响,但并未真正减少模型参数量或实现运行时加速,因被掩蔽的部分仍然要参与计算。

  为了让剪枝后的模型达到真正加速的效果,需要将模块替换为修剪后的模块,使剪枝效果永久化(类似PyTorch剪枝模块pruneremove函数的效果,详见《datawhale11月组队学习 模型压缩技术2:PyTorch模型剪枝教程》)。

  ModelSpeedup 是 NNI 提供的工具,用来实现剪枝后的模型加速,其主要作用是:

  • 形状推断: 自动根据掩模(mask)推断哪些层需要调整形状,并将其与模型拓扑结合应用。
  • 掩模传播: 掩模不仅仅对当前层生效,还会沿着模型的前后传播到相关联的层。
  • 剪枝永久化: 替换掉原始模型的相关层,生成一个真正更小、更高效的模型。这包括:
    • 使用较小的层替换粗粒度掩模(减少参数量和计算量);
    • 使用用稀疏内核替换细粒度掩模(对稀疏性进行专门优化)。

  粗粒度剪枝(如通道剪枝)通常会改变层的权重形状或输入/输出张量的维度。由于层之间存在连接性(拓扑结构),当一个层的形状改变时,与之相连的层可能也需要调整。例如:如果一个卷积层剪除了某些输出通道,那么连接到该层输出的下一个卷积层的输入通道数也需要相应减少。
NNI会利用 PyTorch 的 torch.fx 进行模型跟踪,获取模型的计算图(拓扑结构),并自动推断各个模块需要如何调整形状。

  如果模型之前被剪枝器包裹过(例如用于剪枝过程中插入了钩子或逻辑),需要在加速之前还原为原始未包裹的模型。然后再利用 ModelSpeedup 工具对模型进行实际的加速。

from nni.compression.speedup import ModelSpeedup

pruner.unwrap_model()
ModelSpeedup(model, torch.rand(3, 1, 28, 28).to(device), masks).speedup_model()
model
TorchModel(
  (conv1): Conv2d(1, 3, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(3, 8, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=128, out_features=60, bias=True)
  (fc2): Linear(in_features=60, out_features=42, bias=True)
  (fc3): Linear(in_features=42, out_features=10, bias=True)
  (relu1): ReLU()
  (relu2): ReLU()
  (relu3): ReLU()
  (relu4): ReLU()
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
ModelSpeedup参数 类型 说明
model torch.nn.Module 用户需要加速的模型。
dummy_input 任意类型 提供一个输入样本,用来帮助推断模型的形状变化(此处假设输入是一个形状为 (3, 1, 28, 28) 的张量)。
masks_or_file 任意类型 剪枝过程中生成的掩模,用来标识模型的稀疏性。
map_location 任意类型 掩模所存放的设备,与 torch.load 中的 map_location 参数相同。
batch_dim 整数 dummy_input 中的批量维度索引。
batch_size 整数 dummy_input 的批量大小。
customized_mask_updaters List[MaskUpdater]None 一个 MaskUpdater 列表。NNI 会根据前向和后向传播的数值分布自动推断稀疏性。用户也可以自定义推断规则
customized_replacers List[Replacer]None 一个 Replacer 列表,用于将原始模块替换为压缩模块。用户可以通过自定义 Replacer 来定义替换逻辑。
graph_module torch.fx.GraphModuleNone 如果 ModelSpeedup 的默认具体跟踪逻辑无法满足需求,用户可以直接传入 torch.fx.GraphModule
logger logging.LoggerNone 设置一个日志记录器。如果为 None,NNI 将使用默认的日志记录器。
print('Pruned model paramater number: ', sum([param.numel() for param in model.parameters()]))

start = time.time()
model(torch.rand(128, 1, 28, 28).to(device))
print('Speedup Model - Elapsed Time : ', time.time() - start)
Speedup Model - Elapsed Time :  0.09416508674621582

  在PyTorch中,当前实现只能替换整个模块(module)。如果需要替换模型中的某个函数(function),当前的实现不支持。作为一种变通方法,可以将需要替换的函数转换为PyTorch模块,这样就可以使用现有的替换机制。

2.4 微调压缩模型

optimizer = SGD(model.parameters(), 1e-2)
for epoch in range(3):
    trainer(model, optimizer, criterion)

2.5 Slim Pruner测试

  我们在cifar10数据集上使用Slim Pruner对resnet18模型进行剪枝测试,测量了在不同稀疏比下剪枝模型的延迟和准确率,结果如下:

在这里插入图片描述

三、 使用NNI3.0进行Bert压缩(剪枝、蒸馏)

参考《Pruning Bert on Task MNLI》

  本章结合剪枝、蒸馏两种模型压缩技术,以及新的、更强大的剪枝加速工具对Bert模型进行了剪枝,这大大减少了模型的大小。整个剪枝过程分为三个步骤:

  1. 剪枝注意力层(attention layers)。由于修剪后模型性能下降,使用动态蒸馏(dynamic distiller)进行微调,恢复模型性能(将修剪前后的模型分别作为教师模型和学生模型,将二者Transformer Block每一层都进行对齐蒸馏,实现跨层知识传递);
  2. 剪枝前馈层(feed forward layers),进一步减少模型的参数量,然后同样使用动态蒸馏进行微调;
  3. 剪枝嵌入层(embedding layers)。剪枝后学生模型的Transformer Block维度和教师模型不一致,不能再使用动态蒸馏方法。此时要使用自适应蒸馏(adapt_distiller),它会将学生模型和教师模型的每个Transformer Block的输出层添加一个线性层,对齐二者的维度,进行蒸馏(也就是只蒸馏二者的输出层)。

  在每个步骤中,首先使用剪枝器进行模拟剪枝,生成与模块剪枝目标(权重、输入、输出)对应的掩码。然后进入加速阶段,使用稀疏传播来探索局部掩码导致的全局冗余,接着通过替换模型中的子模块将原始模型修改为更小的模型,最后使用蒸馏器来帮助恢复模型的准确率。

  另外在训练时需要用NNI的trace功能包装transformers.Trainer,以追踪初始化参数,这是因为NNI需要在训练期间重新创建训练器,以实现剪枝和蒸馏的感知。

  本教程使用的是bert-base-uncased模型,任务来自于GLUE Benchmark.中的MNLI 。GLUE榜单包含了9个句子级别的分类任务,分别是:

  • CoLA (Corpus of Linguistic Acceptability) :鉴别一个句子是否语法正确
  • MNLI (Multi-Genre Natural Language Inference) :给定一对句子(一个前提句和一个假设句),模型需要判断两者之间的关系属于以下三种之一:
    • Entailment(蕴含):假设句可以从前提句推导出来。
    • Contradiction(矛盾):假设句与前提句相矛盾。
    • Neutral(中性):假设句与前提句既不蕴含也不矛盾。
  • MRPC (Microsoft Research Paraphrase Corpus) :判断两个句子是否互为paraphrases
  • QNLI (Question-answering Natural Language Inference) :判断第2句是否包含第1句问题的答案。
  • QQP (Quora Question Pairs2) :判断两个问句是否语义相同
  • RTE (Recognizing Textual Entailment):判断一个句子是否与假设成entail关系。
  • SST-2 (Stanford Sentiment Treebank) :判断一个句子的情感正负向
  • STS-B (Semantic Textual Similarity Benchmark) :判断两个句子的相似性(分数为1-5分)。
  • WNLI (Winograd Natural Language Inference) :判断包含匿名代词的句子与将该代词替换后的句子是否具有蕴含关系。

有关GLUE Benchmark的更多分析详见《微调预训练模型进行文本分类》

3.1 数据预处理

from __future__ import annotations

from pathlib import Path

import numpy as np

import torch
from torch.utils.data import ConcatDataset

import nni

from datasets import load_dataset, load_metric
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments

定义任务和模型:

# 需要注意的是,STS-B是一个回归问题,MNLI是一个3分类问题,其它都是二分类。
task_name = 'mnli'
def build_model(pretrained_model_name_or_path: str, task_name: str):
    is_regression = task_name == 'stsb'
    num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
    model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
    return model

准备GLUE训练和验证数据集,如果任务有多个验证数据集,则通过ConcatDataset合并数据集。

def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str):
	
	# 不同任务的输入结构不同,因此我们定义下面这个dict
    task_to_keys = {
        'cola': ('sentence', None),
        'mnli': ('premise', 'hypothesis'),
        'mrpc': ('sentence1', 'sentence2'),
        'qnli': ('question', 'sentence'),
        'qqp': ('question1', 'question2'),
        'rte': ('sentence1', 'sentence2'),
        'sst2': ('sentence', None),
        'stsb': ('sentence1', 'sentence2'),
        'wnli': ('sentence1', 'sentence2'),
    }
    sentence1_key, sentence2_key = task_to_keys[task_name]

    # 定义与处理函数
    def preprocess_function(examples):
        # args 变量区分了单句输入和双句输入的情况
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=False, max_length=128, truncation=True)

        if 'label' in examples:
            # 如果数据包含标签(label),则将其重命名为 labels,以适配模型的输入要求
            result['labels'] = examples['label']
        return result
	
	# 从 GLUE 数据集中加载对应任务的数据
    raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
    # 检查数据集中是否包含测试集(没有标签),如果有则移除。
    for key in list(raw_datasets.keys()):
        if 'test' in key:
            raw_datasets.pop(key)
	# 批处理后,移除原始数据集中所有的列,只保留分词后的结果和标签列
    processed_datasets = raw_datasets.map(preprocess_function, batched=True,
                                          remove_columns=raw_datasets['train'].column_names)
	
	# 获取训练集和验证集
    train_dataset = processed_datasets['train']
    if task_name == 'mnli':
        validation_datasets = {
            'validation_matched': processed_datasets['validation_matched'],
            'validation_mismatched': processed_datasets['validation_mismatched']
        }
    else:
        validation_datasets = {
            'validation': processed_datasets['validation']
        }

    return train_dataset, validation_datasets

MNLI 的验证集分为两部分:

  • validation_matched:前提句和假设句来自相同领域。
  • validation_mismatched:前提句和假设句来自不同领域。

其他任务只有一个验证集,因此直接提取。

3.2 训练模型

准备训练器,注意Trainer类是由nni.trace包装的。

def prepare_traced_trainer(model, task_name, load_best_model_at_end=False):
	# 加载 GLUE数据集中任务对应的评估指标,只有stsb是回归任务
    is_regression = task_name == 'stsb'
    metric = load_metric('glue', task_name)

    def compute_metrics(p: EvalPrediction):
    	# 如果p.predictions是元组,就取第一个结果
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        # 回归任务直接去掉多余的维度,保留连续值结果(np.squeeze)
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        result = metric.compute(predictions=preds, references=p.label_ids)
        # 优先返回 F1分数,如果任务没有 f1,则返回 accuracy,都没有返回默认值 0.0
        result['default'] = result.get('f1', result.get('accuracy', 0.))
        return result

    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, None)
    # MNLI有两个验证集,则将其合并
    merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])
    # 创建数据整理器,进行动态填充
    data_collator = DataCollatorWithPadding(tokenizer)
    
    training_args = TrainingArguments(output_dir='./output/trainer',
                                      do_train=True,
                                      do_eval=True,
                                      evaluation_strategy='steps',
                                      per_device_train_batch_size=32,
                                      per_device_eval_batch_size=32,
                                      num_train_epochs=3,
                                      dataloader_num_workers=12,
                                      learning_rate=3e-5,
                                      save_strategy='steps',
                                      save_total_limit=1,
                                      metric_for_best_model='default',
                                      load_best_model_at_end=load_best_model_at_end,
                                      disable_tqdm=True,
                                      optim='adamw_torch',
                                      seed=1024)
                                      
    trainer = nni.trace(Trainer)(model=model,
                                 args=training_args,
                                 data_collator=data_collator,
                                 train_dataset=train_dataset,
                                 eval_dataset=merged_validation_dataset,
                                 tokenizer=tokenizer,
                                 compute_metrics=compute_metrics,)
    return trainer

  如果存在微调后的模型,则直接加载;如果不存在微调后的模型,则使用训练器训练预训练后的模型。

def build_finetuning_model(task_name: str, state_dict_path: str):

    model = build_model('bert-base-uncased', task_name)
    if Path(state_dict_path).exists():
        model.load_state_dict(torch.load(state_dict_path))
    else:
        trainer = prepare_traced_trainer(model, task_name, True)
        trainer.train()
        torch.save(model.state_dict(), state_dict_path)
    return model


Path('./output/bert_finetuned').mkdir(exist_ok=True, parents=True)
build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')

mkdir(exist_ok=True, parents=True):使用 Path 对象创建路径。

  • exist_ok=True:如果目录已存在,不会抛出错误。
  • parents=True:如果父目录不存在,递归创建父目录。

3.3 设置模型蒸馏函数

  动态蒸馏(Dynamic distillation )适用于学生模型和教师模型蒸馏状态维度(distillation states,中间层维度)相匹配的情况。在这种技术中,学生模型可以尝试从多个教师状态中蒸馏知识,并最终选择蒸馏损失最小的教师状态作为蒸馏目标。在本章中,动态蒸馏被应用于加速嵌入层剪枝(embedding pruning)之前。

from nni.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller
from nni.compression.utils import TransformersEvaluator

def dynamic_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
                      student_trainer: Trainer):

	"""
	student_model:学生模型,通常是简化版本的 BERT 模型。
	teacher_model:教师模型,通常是预训练且性能更强的 BERT 模型。
	student_trainer:Hugging Face Trainer 类实例,负责学生模型的训练和验证。
	"""
	
	# 获取学生模型的 Transformer 层数
    layer_num = len(student_model.bert.encoder.layer)
    config_list = [{
        'op_names': [f'bert.encoder.layer.{i}'],							# 指定当前学生模型中的第i个操作(层)
        'link': [f'bert.encoder.layer.{j}' for j in range(i, layer_num)],   # 关联教师模型的对应层或更深层(跨层链接)
        'lambda': 0.9,														# 权重因子,用于控制蒸馏损失在总损失中的比重。
        'apply_method': 'mse',												# 采用的蒸馏方法(此处为 MSE)
    } for i in range(layer_num)]											
    
    # 最后一层是分类器,对其输出进行蒸馏,采用 KL 散度(常用于分类任务的概率分布蒸馏)
    config_list.append({
        'op_names': ['classifier'],
        'link': ['classifier'],				# 将学生模型和教师模型的分类层连接
        'lambda': 0.9,
        'apply_method': 'kl',
    })

    evaluator = TransformersEvaluator(student_trainer)

    def teacher_predict(batch, teacher_model):
        return teacher_model(**batch)

    return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)

函数最终返回动态蒸馏器:

  • student_model:学生模型。
  • config_list:配置列表,定义了各层的蒸馏策略。
  • evaluator:用于评估学生模型性能。
  • teacher_modelteacher_predict:定义教师模型和预测逻辑。
  • origin_loss_lambda:学生模型原始任务损失(如分类交叉熵)在总损失中的权重。

def dynamic_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
                         max_steps: int | None, max_epochs: int | None):

	# 初始化学生模型的训练逻辑,包括数据加载和训练器
    student_trainer = prepare_traced_trainer(student_model, task_name, True)
    
	# 保存教师模型的原始设备和训练模式(训练模式决定模型是否更新)
    ori_teacher_device = teacher_model.device
    training = teacher_model.training
    # 将教师模型切换到学生训练器所在的设备,并设置为评估模式
    teacher_model.to(student_trainer.args.device).eval()
    
	# 初始化动态蒸馏器
    distiller = dynamic_distiller(student_model, teacher_model, student_trainer)
	# 调用 compress 方法,执行蒸馏训练(根据 max_steps 或 max_epochs 控制训练过程)
    distiller.compress(max_steps, max_epochs)
    # 调用 unwrap_model 方法,恢复学生模型为非封装状态。
    distiller.unwrap_model()
    
	# 将教师模型还原到原始设备和训练模式,避免影响其他任务
    teacher_model.to(ori_teacher_device).train(training)

  在嵌入层(输入层)剪枝之后应用自适应蒸馏(Adapt distillation)。这是因为嵌入层修剪之后,学生模型和教师模型的隐藏状态维度不一致,自适应蒸馏器会在每对学生模型和教师模型的蒸馏模块对之间添加一个线性层来对齐输出层(分类层)维度。

  在剪枝了输入嵌入层后,学生模型的嵌入维度可能变小(例如从 768 降低到 384),那么对于每个学生模型的Transformer块,都会添加一个线性层Linear(in_features=384, out_features=768),将维度从384调整到768,以与教师模型的Transformer块输出对齐。

def adapt_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
                    student_trainer: Trainer):
                    
    layer_num = len(student_model.bert.encoder.layer)
    config_list = [{
        'op_names': [f'bert.encoder.layer.{i}'],	# 对学生模型中的每个Transformer 层进行蒸馏
        'lambda': 0.9,								# 蒸馏损失权重系数
        'apply_method': 'mse',
    } for i in range(layer_num)]
    
    # 添加对分类器层的蒸馏
    config_list.append({
        'op_names': ['classifier'],
        'link': ['classifier'],
        'lambda': 0.9,
        'apply_method': 'kl',
    })

    evaluator = TransformersEvaluator(student_trainer)

    def teacher_predict(batch, teacher_model):
        return teacher_model(**batch)

    return Adaptive1dLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)


def adapt_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
                       max_steps: int | None, max_epochs: int | None):
                       
    student_trainer = prepare_traced_trainer(student_model, task_name, True)
    
	# 保存教师模型的原始设备和训练模式(训练模式决定模型是否更新)
    ori_teacher_device = teacher_model.device
    training = teacher_model.training
    teacher_model.to(student_trainer.args.device).eval()

    distiller = adapt_distiller(student_model, teacher_model, student_trainer)
    # 创建一个虚拟输入(dummy_input),模拟实际输入的形状和数据,再将其转移到正确的设备
    dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
    dummy_input = [_.to(student_trainer.args.device) for _ in dummy_input]
     # 开始追踪蒸馏过程的前向传播
    distiller.track_forward(*dummy_input)
	# 进行蒸馏,压缩模型并进行训练,然后解包模型
    distiller.compress(max_steps, max_epochs)
    distiller.unwrap_model()

    teacher_model.to(ori_teacher_device).train(training)
  • dynamic_distiller适用于学生模型和教师模型结构相似(维度一致)的场景,可以实现跨层知识传递——config_list 中的 link 参数可以指定学生模型的某一层与教师模型的某一层相连,动态建立对齐关系。
  • adapt_distiller往往用于学生模型压缩或剪枝之后与教师模型维度不一致的情况,通过添加额外的变换层将其输出层(分类层)与教师模型输出层进行对齐。中间层是不需要对齐的,所以不需要添加跨层连接link。

   track_forward 主要用于Adapt distillation,它需要在适配型蒸馏中显式调用,以确保维度对齐和蒸馏过程的正确性,而 dynamic_distiller 则通过静态的层级配置(如 config_list 中定义的 link)来自动处理这些问题,因此不需要显式调用 track_forward

3.4 修剪注意力层

  1. Attention层修剪:通过使用 MovementPruner 进行注意力层的修剪,生成块稀疏掩码,并在训练后保存修剪后的模型。
  2. 模型加速:通过使用 ModelSpeedup 加速模型,应用注意力层修剪的掩码,并进行动态蒸馏(Distillation)。
from nni.compression.pruning import MovementPruner
from nni.compression.speedup import ModelSpeedup
from nni.compression.utils.external.external_replacer import TransformersAttentionReplacer

def pruning_attn():

	# 创建路径,加载BERT模型
    Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)
    model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
    trainer = prepare_traced_trainer(model, task_name)
    evaluator = TransformersEvaluator(trainer)

    config_list = [{
        'op_types': ['Linear'],
        'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention\.*'],
        'sparse_threshold': 0.1,
        'granularity': [64, 64]  # 64 x 64 block granularity
    }]

    pruner = MovementPruner(model, config_list, evaluator, warmup_step=9000, cooldown_begin_step=36000, regular_scale=10)
    pruner.compress(None, 4)  					# 执行 4 次修剪迭代。
    pruner.unwrap_model()  						# Apply pruning to the model

    masks = pruner.get_masks()  				# Get the pruning masks
    Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
    torch.save(masks, './output/pruning/attn_masks.pth')  # Save masks
    torch.save(model, './output/pruning/attn_masked_model.pth')  # Save the pruned model


pruning_attn()
  • 使用 MovementPruner,指定要修剪的操作类型 (Linear),以及匹配的层名(使用正则表达式 op_names_re 来匹配BERT模型中的注意力层)。
  • 设置 sparse_threshold0.1,表示低于此阈值的权重会被修剪。
  • granularity 参数设置为 [64, 64],即采用 64x64 的块进行修剪(相对于单个注意力头的维度)。

  在修剪之后,使用 ModelSpeedup 进行模型加速。此阶段会根据修剪的掩码对模型进行加速,如果某个注意力头完全被掩盖,则该head会被修剪;如果某个head部分被掩盖,则该head会被恢复。

def speedup_attn():

	# 加载修剪后的模型和掩码
    model = torch.load('./output/pruning/attn_masked_model.pth', map_location='cpu')
    masks = torch.load('./output/pruning/attn_masks.pth', map_location='cpu')
    dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
    replacer = TransformersAttentionReplacer(model)
    ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()

    # 使用动态蒸馏进行微调
    teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
    dynamic_distillation(model, teacher_model, None, 3)
    torch.save(model, './output/pruning/attn_pruned_model.pth')

speedup_attn()
  • 使用 ModelSpeedup 加速模型。通过 TransformersAttentionReplacer 替换掉模型中的注意力层。
  • dummy_input 用于模拟数据输入,确保模型加速过程中不会出现问题。
  • 进行剪枝后,学生模型的表现可能会受到影响。通过动态蒸馏,可以让学生模型逐步适应修剪后的新结构,并进一步优化其在特定任务上的性能。

3.5 修剪前馈层

  这里使用TaylorPruner来修剪前馈层,TaylorPruner 是一种基于泰勒展开的剪枝方法。

  剪枝注意力头后,可能会影响前馈层的表示能力,因此前馈层的稀疏率(即剪枝的比例)会根据被剪枝的注意力头数来调整。被剪枝的头越多,稀疏率越大。TaylorPruner没有调度稀疏比率的功能,我们使用AGPPruner来自适应的调节稀疏率,从而获得更好的剪枝性能。

from nni.compression.pruning import TaylorPruner, AGPPruner
from transformers.models.bert.modeling_bert import BertLayer


def pruning_ffn():

	# 加载一个已剪枝的BERT模型 attn_pruned_model.pth,和一个教师模型
    model: BertForSequenceClassification = torch.load('./output/pruning/attn_pruned_model.pth')
    teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
    # 创建一个配置列表 config_list,用于确定前馈层(BertLayer 中的 intermediate.dense)的稀疏比率。
    # 稀疏比率是根据保留的注意力头数和原始头数来计算的。
    config_list = []
    for name, module in model.named_modules():
        if isinstance(module, BertLayer):
            retained_head_num = module.attention.self.num_attention_heads
            ori_head_num = len(module.attention.pruned_heads) + retained_head_num
            ffn_sparse_ratio = 1 - retained_head_num / ori_head_num / 2
            config_list.append({'op_names': [f'{name}.intermediate.dense'], 'sparse_ratio': ffn_sparse_ratio})

    trainer = prepare_traced_trainer(model, task_name)
    teacher_model.eval().to(trainer.args.device)
    
    # 建立一个蒸馏器用于蒸馏模型,恢复精度。
    distiller = dynamic_distiller(model, teacher_model, trainer)
    # 创建了一个 Taylor剪枝器,剪枝后再通过动态蒸馏进行微调(1000steps)
    taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
    # 创建了一个 AGP剪枝器,在TaylorPruner基础上进一步压缩模型
    # 1000和36分别表示每隔多少步更新一次稀疏度以及总共更新多少次稀疏度
    agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
    agp_pruner.compress(None, 3)
    agp_pruner.unwrap_model()
    distiller.unwrap_teacher_model()

    masks = agp_pruner.get_masks()
    Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
    torch.save(masks, './output/pruning/ffn_masks.pth')
    torch.save(model, './output/pruning/ffn_masked_model.pth')

pruning_ffn()

def speedup_ffn():
    model = torch.load('./output/pruning/ffn_masked_model.pth', map_location='cpu')
    masks = torch.load('./output/pruning/ffn_masks.pth', map_location='cpu')
    dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
    # 使用 ModelSpeedup 类和传入的模型、掩码以及伪输入来加速模型。
    ModelSpeedup(model, dummy_input, masks).speedup_model()

    # 在剪枝后的模型上执行动态蒸馏,以恢复精度。
    teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
    dynamic_distillation(model, teacher_model, None, 3)
    torch.save(model, './output/pruning/ffn_pruned_model.pth')

speedup_ffn()

3.6 修剪嵌入层

  剪枝 Embedding层(嵌入层)时,为了更好地模拟剪枝效果,采用了一种更细致的剪枝策略,包括对 BertAttention 和 BertOutput 层的输出进行掩码(mask)设置,并且通过特定的方式来生成和应用这些掩码。

  为了更好地模拟剪枝效果,我们需要为 BertAttention 和 BertOutput 层注册输出掩码设置(output mask setting)。这些层的输出掩码用于决定哪些连接或参数需要保留,哪些需要剪枝。

  BertAttention 和 BertOutput 是 BERT 模型中的关键层,其中 BertAttention 负责计算注意力机制,而 BertOutput 用于生成最终的输出。

from nni.compression.base.setting import PruningSetting

output_align_setting = {
    '_output_': {
        'align': {
            'module_name': None,
            'target_name': 'weight',
            'dims': [0],
        },
        'apply_method': 'mul',
    }
}
PruningSetting.register('BertAttention', output_align_setting)
PruningSetting.register('BertOutput', output_align_setting)

  就像在剪枝前馈层时一样,剪枝嵌入层时也使用 AGPPruner、TaylorPruner 和 DynamicLayerwiseDistiller 三种技术组合来实现。为了更好的修剪效果模拟,在config_list中设置输出对齐掩码生成,这样相关层的输出掩码将根据嵌入掩码自动调整,从而在剪枝过程中更加一致和有效。

def pruning_embedding():
    model: BertForSequenceClassification = torch.load('./output/pruning/ffn_pruned_model.pth')
    teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')

    sparse_ratio = 0.5
    config_list = [{
        'op_types': ['Embedding'],
        'op_names_re': ['bert\.embeddings.*'],
        'sparse_ratio': sparse_ratio,
        'dependency_group_id': 1,
        'granularity': [-1, 1],
    }, {
        'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention$',
                        'bert\.encoder\.layer\.[0-9]*\.output$'],
        'target_names': ['_output_'],
        'target_settings': {
            '_output_': {
                'align': {
                    'module_name': 'bert.embeddings.word_embeddings',
                    'target_name': 'weight',
                    'dims': [1],
                }
            }
        }
    }, {
        'op_names_re': ['bert\.encoder\.layer\.[0-9]*\.attention.output.dense',
                        'bert\.encoder\.layer\.[0-9]*\.output.dense'],
        'target_names': ['weight'],
        'target_settings': {
            'weight': {
                'granularity': 'out_channel',
                'align': {
                    'module_name': 'bert.embeddings.word_embeddings',
                    'target_name': 'weight',
                    'dims': [1],
                }
            }
        }
    }]

    trainer = prepare_traced_trainer(model, task_name)
    teacher_model.eval().to(trainer.args.device)
    distiller = dynamic_distiller(model, teacher_model, trainer)
    taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
    agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
    agp_pruner.compress(None, 3)
    agp_pruner.unwrap_model()
    distiller.unwrap_teacher_model()

    masks = agp_pruner.get_masks()
    Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
    torch.save(masks, './output/pruning/embedding_masks.pth')
    torch.save(model, './output/pruning/embedding_masked_model.pth')

pruning_embedding()
# 加速嵌入层修剪后的模型,并进行自适应蒸馏微调
def speedup_embedding():
    model = torch.load('./output/pruning/embedding_masked_model.pth', map_location='cpu')
    masks = torch.load('./output/pruning/embedding_masks.pth', map_location='cpu')
    dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
    ModelSpeedup(model, dummy_input, masks).speedup_model()

    teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
    adapt_distillation(model, teacher_model, None, 4)
    torch.save(model, './output/pruning/embedding_pruned_model.pth')

speedup_embedding()

3.8 模型评估

def evaluate_pruned_model():
    model: BertForSequenceClassification = torch.load('./output/pruning/embedding_pruned_model.pth')
    trainer = prepare_traced_trainer(model, task_name)
    metric = trainer.evaluate()
    pruned_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())

    model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
    ori_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
    print(f'Metric: {metric}\nSparsity: {1 - pruned_num_params / ori_num_params}')

evaluate_pruned_model()
Total Sparsity Embedding Sparsity Encoder Sparsity Pooler Sparsity Acc. (m/mm avg.)
0.% 0.% 0.% 0.% 84.95%
57.76% 33.33% (15.89M) 64.78% (29.96M) 33.33% (0.39M) 84.42%
68.31% (34.70M) 50.00% (11.92M) 73.57% (22.48M) 50.00% (0.30M) 83.33%
70.95% (31.81M) 33.33% (15.89M) 81.75% (15.52M) 33.33% (0.39M) 83.79%
78.20% (23.86M) 50.00% (11.92M) 86.31% (11.65M) 50.00% (0.30M) 82.53%
81.65% (20.12M) 50.00% (11.92M) 90.71% (7.90M) 50.00% (0.30M) 82.08%
84.32% (17.17M) 50.00% (11.92M) 94.18% (4.95M) 50.00% (0.30M) 81.35%

  原始模型准确率为 84.95%,随着稀疏度逐渐增加至84.32%,模型准确率依旧保持在81%以上,说明通过蒸馏等方法的辅助,模型在剪枝后(结合 AGPPruner 和 TaylorPruner)仍能保持较高的准确度。