Introduction
The earlier posts《深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(1)》and《深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(2)》explained the pruning approach and walked through a hands-on run, but only a single layer was pruned there. This post documents the process of pruning all of the residual layers in ResNet18. As noted before, the first convolution layer is left untouched; the reasoning is covered in the earlier posts and is also revisited below.
1. Overview of the ResNet18 Architecture
ResNet18 is a classic lightweight residual network whose core idea is to use residual blocks (BasicBlock) to counter vanishing gradients in deep networks. The full structure (after adjusting for CIFAR-10) is:
Layer | Type | Input size | Output size | Key parameters | Role |
---|---|---|---|---|---|
conv1 | Convolution | 3×32×32 | 64×32×32 | kernel=3, stride=1, pad=1 | Initial feature extraction |
bn1 | BatchNorm | 64×32×32 | 64×32×32 | num_features=64 | Normalization, faster training |
relu | Activation | 64×32×32 | 64×32×32 | - | Introduces non-linearity |
maxpool | Max pooling | 64×32×32 | 64×16×16 | kernel=3, stride=2, pad=1 | Spatial downsampling |
layer1 | Residual group (2 BasicBlocks) | 64×16×16 | 64×16×16 | each block has two 3×3 convolutions | Early feature refinement |
layer2 | Residual group (2 BasicBlocks) | 64×16×16 | 128×8×8 | first block downsamples with stride=2 | Wider features + downsampling |
layer3 | Residual group (2 BasicBlocks) | 128×8×8 | 256×4×4 | first block downsamples with stride=2 | Deeper feature abstraction |
layer4 | Residual group (2 BasicBlocks) | 256×4×4 | 512×2×2 | first block downsamples with stride=2 | High-level semantic features |
avgpool | Global average pooling | 512×2×2 | 512×1×1 | - | Collapses spatial dims to 1×1 |
fc | Fully connected | 512 | 10 | in_features=512, out=10 | Classification output |
Note: the pruning targets in this article are the residual blocks in layer1 through layer4 (8 BasicBlocks in total); the global conv1 layer is skipped.
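For reference, a quick way to inspect this structure (a small standalone sketch using torchvision's stock resnet18, with the 10-class head replacement from the full script in Section 5) is to print the top-level children:

import torch.nn as nn
from torchvision.models import resnet18

model = resnet18()
model.fc = nn.Linear(512, 10)  # 10-class head for CIFAR-10, as in the full script below

for name, child in model.named_children():
    print(name, type(child).__name__)
# conv1 Conv2d, bn1 BatchNorm2d, relu ReLU, maxpool MaxPool2d,
# layer1..layer4 Sequential (each holding 2 BasicBlocks), avgpool AdaptiveAvgPool2d, fc Linear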
2. Pruning Strategy: Skip the First Layer, Prune the Residual Blocks
2.1 Why skip the first layer?
ResNet's first convolution (conv1) receives the raw input directly (a 3×32×32 image), and its weights extract basic features such as edges and textures. Pruning this layer risks breaking the feature alignment between the input and the subsequent layers, which can cause a large accuracy drop. The strategy in this article is therefore: keep the global conv1 and only prune the convolution layers inside the subsequent residual blocks.
2.2 Residual-block pruning logic
Each residual block (BasicBlock) contains two 3×3 convolution layers (conv1 and conv2) plus the corresponding bn1 layer. The pruning targets are:
- prune the output channels of the block's first convolution (conv1) by L1 norm;
- update the input channels of the second convolution (conv2) in sync, so they match conv1's pruned output channels;
- adjust bn1's num_features and its running statistics (running_mean / running_var) to match the new channel count.
A minimal sketch of the channel-selection step follows this list.
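The sketch below mirrors the prune_conv_layer helper (shown in full in Section 5); the standalone function name and the toy 64-channel conv are only for illustration:

import numpy as np
import torch.nn as nn

def l1_channel_mask(conv: nn.Conv2d, percent_to_prune: float):
    # Importance score per output channel = L1 norm of that channel's filter weights
    weights = conv.weight.data.cpu().numpy()           # shape: [out, in, k, k]
    l1_norm = np.abs(weights).sum(axis=(1, 2, 3))      # one score per output channel
    num_prune = int(percent_to_prune * len(l1_norm))   # how many channels to drop
    keep = np.argsort(l1_norm)[num_prune:]             # keep the largest-norm channels
    mask = np.zeros(len(l1_norm), dtype=bool)
    mask[keep] = True                                   # True = keep this channel
    return mask

# Example: pruning 20% of a 64-channel conv keeps 64 - int(64*0.2) = 52 channels
mask = l1_channel_mask(nn.Conv2d(64, 64, 3, padding=1), 0.2)
print(mask.sum())  # 52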
3. Code Walkthrough
3.1 Core pruning function: prune_resnet_block
This function prunes a single residual block. The key steps are shown below (abridged; the full version is in Section 5):
def prune_resnet_block(block, percent_to_prune):
    # Prune the first convolution of the block (block.conv1)
    conv1 = block.conv1
    mask1 = prune_conv_layer(conv1, percent_to_prune)  # boolean mask of channels to keep
    if mask1 is not None:
        # 1. Rebuild conv1, keeping only the output channels selected by the mask
        new_conv1 = nn.Conv2d(
            in_channels=conv1.in_channels,
            out_channels=sum(mask1),  # channel count after pruning
            kernel_size=conv1.kernel_size,
            stride=conv1.stride,
            padding=conv1.padding,
            bias=conv1.bias is not None
        )
        new_conv1.weight.data = conv1.weight.data[mask1, :, :, :]  # slice weights by mask
        # 2. Rebuild conv2 so its input channels match conv1's pruned output
        conv2 = block.conv2
        new_conv2 = nn.Conv2d(
            in_channels=sum(mask1),  # key point: input channels are pruned in sync
            out_channels=conv2.out_channels,
            kernel_size=conv2.kernel_size,
            stride=conv2.stride,
            padding=conv2.padding,
            bias=conv2.bias is not None
        )
        new_conv2.weight.data = conv2.weight.data[:, mask1, :, :]  # slice input-channel weights
        # 3. Rebuild bn1 so num_features matches the pruned channel count
        if hasattr(block, 'bn1'):
            bn1 = block.bn1
            new_bn1 = nn.BatchNorm2d(sum(mask1))
            new_bn1.weight.data = bn1.weight.data[mask1]    # slice affine weight
            new_bn1.bias.data = bn1.bias.data[mask1]        # slice affine bias
            new_bn1.running_mean = bn1.running_mean[mask1]  # slice running statistics
            new_bn1.running_var = bn1.running_var[mask1]
            block.bn1 = new_bn1
        # Swap the new layers into the block
        block.conv1, block.conv2 = new_conv1, new_conv2
        return mask1
Key points:
- prune_conv_layer scores each output channel by the L1 norm of its filter (np.sum(np.abs(weights), axis=(1, 2, 3))) and keeps the top (1 - percent) channels;
- mask1 is a boolean mask (True = keep), so sum(mask1) is the channel count after pruning;
- conv2's weights are sliced with [:, mask1, :, :], which keeps its input channels aligned with conv1's pruned output;
- bn1's num_features, weight, running_mean and so on are all truncated by mask1, which avoids dimension-mismatch errors (such as the running_mean length mismatch hit in the earlier post).
A toy example of how the mask slices these tensors follows.
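To make the slicing concrete, here is a toy example (the 4-channel conv and the hand-written mask are made up purely for illustration):

import numpy as np
import torch.nn as nn

conv1 = nn.Conv2d(8, 4, kernel_size=3, padding=1)   # weight shape: [4, 8, 3, 3]
conv2 = nn.Conv2d(4, 8, kernel_size=3, padding=1)   # weight shape: [8, 4, 3, 3]
mask1 = np.array([True, False, True, False])         # pretend 2 of conv1's 4 output channels survive

print(conv1.weight.data[mask1, :, :, :].shape)       # torch.Size([2, 8, 3, 3]) -> new conv1 weight
print(conv2.weight.data[:, mask1, :, :].shape)       # torch.Size([8, 2, 3, 3]) -> new conv2 weight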
3.2 Global pruning control: the prune_model function
This function walks over all of ResNet18's residual blocks, skips the global conv1, and only processes the BasicBlocks in layer1 through layer4:
def prune_model(model, pruning_percent):
    # Collect every residual block (the global conv1 is not a BasicBlock, so it is skipped)
    blocks = []
    for name, module in model.named_modules():
        if isinstance(module, torchvision.models.resnet.BasicBlock):
            blocks.append((name, module))  # all BasicBlocks from layer1 to layer4
    # Prune each residual block
    for name, block in blocks:
        print(f"Pruning {name}...")
        mask = prune_resnet_block(block, pruning_percent)
    return model
Key point: filtering modules with isinstance(module, BasicBlock) selects exactly the residual blocks, so only the target layers are pruned. A quick way to see which modules this matches is shown below.
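As a sanity check (a small standalone sketch using torchvision's stock resnet18), listing the matched module names shows exactly the eight blocks layer1.0 through layer4.1, and nothing from the stem:

import torchvision
from torchvision.models import resnet18

model = resnet18()
blocks = [name for name, m in model.named_modules()
          if isinstance(m, torchvision.models.resnet.BasicBlock)]
print(blocks)
# ['layer1.0', 'layer1.1', 'layer2.0', 'layer2.1',
#  'layer3.0', 'layer3.1', 'layer4.0', 'layer4.1']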
4. Experiments and Results
4.1 Model structure before and after pruning
The print_model_shapes function prints the key layer parameters before and after pruning (using the layer1.0 block as an example):
Layer | Before pruning | After pruning (20%) | Change |
---|---|---|---|
layer1.0.conv1 | in=64, out=64 | in=64, out=52 (64 − int(64×0.2)) | 12 output channels removed |
layer1.0.bn1 | num_features=64 | num_features=52 | follows conv1's pruned output channels |
layer1.0.conv2 | in=64, out=64 | in=52, out=64 | input channels match conv1's output |
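A quick consistency check over the pruned model (a sketch; pruned_model here refers to the model returned by prune_model in the full script below) confirms that each block's conv2 input matches its conv1 output and that bn1 follows along:

import torchvision

for name, block in pruned_model.named_modules():
    if isinstance(block, torchvision.models.resnet.BasicBlock):
        assert block.conv2.in_channels == block.conv1.out_channels, name
        assert block.bn1.num_features == block.conv1.out_channels, name
        print(f"{name}: conv1 out={block.conv1.out_channels}, "
              f"conv2 in={block.conv2.in_channels}, bn1={block.bn1.num_features}")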
4.2 Parameter count and accuracy
- Parameter count: the original model has 11,181,642 parameters; pruning reduces this to 8,996,114, a drop of roughly 19.5% (see the torchinfo summaries below);
Original model summary:
==========================================================================================
Total params: 11,181,642
Trainable params: 11,181,642
Non-trainable params: 0
Total mult-adds (M): 37.03
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.81
Params size (MB): 44.73
Estimated Total Size (MB): 45.55
==========================================================================================
Pruned model summary:
==========================================================================================
Total params: 8,996,114
Trainable params: 8,996,114
Non-trainable params: 0
Total mult-adds (M): 30.35
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.76
Params size (MB): 35.98
Estimated Total Size (MB): 36.76
==========================================================================================
- Accuracy: the baseline reached 71.92%, while the pruned model was fine-tuned up to 82.05% (the original model was fine-tuned for 20 epochs, the pruned one for 15).
- Something looks off here: the pruned model ends up more accurate than the baseline. Is this caused by the different fine-tuning settings used afterwards? If you know, please let me know! One likely factor is that the two runs are not directly comparable: the baseline was trained for only 20 epochs at lr=0.1 with CosineAnnealingLR(T_max=200), so the learning rate had barely decayed and the model was still under-converged, whereas the pruned model then received 15 further epochs at lr=0.01 on top of that checkpoint.
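One way to check (a sketch reusing model, fine_tune_model and the data loaders from the full script below; this run was not part of the original experiment) is to give the unpruned baseline the same extra fine-tuning budget as the pruned model and compare again:

# Hypothetical fairness check: fine-tune the unpruned baseline with the same
# settings the pruned model received (lr=0.01, 15 epochs), then compare accuracies.
# Note: fine_tune_model saves to 'best_model.pth'; use a different path in a real run.
baseline = copy.deepcopy(model)
optimizer_base = optim.SGD(baseline.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler_base = optim.lr_scheduler.CosineAnnealingLR(optimizer_base, T_max=100)
best_base_acc = fine_tune_model(baseline, trainloader, testloader, criterion,
                                optimizer_base, scheduler_base, 15)
print(f"Baseline after equal fine-tuning: {best_base_acc:.2f}%")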
5. Summary and Outlook
No real summary this time; here is the complete code:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import numpy as np
from collections import OrderedDict
import copy
from torchinfo import summary
def make_resnet18_cifar10():
    model = resnet18(pretrained=True)
    # Optionally adapt the first conv layer to CIFAR-10's 32x32 images (left commented out here)
    #model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # Replace the final fully connected layer with a 10-class head for CIFAR-10
    num_ftrs = model.fc.in_features
    #model.fc = nn.Linear(num_ftrs, 10)
    model.fc = nn.Linear(512, 10)
    return model
def train(model, trainloader, criterion, optimizer, epoch):
model.train()
running_loss = 0.0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
train_loss = running_loss / len(trainloader)
train_acc = 100. * correct / total
print(f'Train Epoch: {epoch} | Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%')
return train_loss, train_acc
def test(model, testloader, criterion):
model.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
test_loss /= len(testloader)
test_acc = 100. * correct / total
print(f'Test set: Average loss: {test_loss:.4f} | Acc: {test_acc:.2f}%\n')
return test_loss, test_acc
def print_model_size(model):
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
def prune_conv_layer(conv, percent_to_prune):
    weights = conv.weight.data.cpu().numpy()
    # L1 norm of each filter as the importance score (sum over axes (1, 2, 3))
    l1_norm = np.sum(np.abs(weights), axis=(1, 2, 3))
    # Number of output channels to remove
    num_prune = int(percent_to_prune * len(l1_norm))
    if num_prune > 0:
        print(f"🔍 Pruning {conv} output channels from {conv.out_channels} → {conv.out_channels - num_prune}")
        # Indices of the channels to keep (those with the largest L1 norm)
        keep_indices = np.argsort(l1_norm)[num_prune:]
        mask = np.zeros(len(l1_norm), dtype=bool)
        mask[keep_indices] = True  # True means keep this channel
        return mask
    return None
def prune_resnet_block(block, percent_to_prune):
    # Prune the first convolution layer of the block
    conv1 = block.conv1
    print(f"Before pruning, conv1 out_channels: {conv1.out_channels}")
    mask1 = prune_conv_layer(conv1, percent_to_prune)
    if mask1 is not None:
        print(f"After pruning, mask1 sum: {sum(mask1)}")
        # Rebuild the first conv layer with the pruned output channels
        new_conv1 = nn.Conv2d(
            in_channels=conv1.in_channels,
            out_channels=sum(mask1),
            kernel_size=conv1.kernel_size,
            stride=conv1.stride,
            padding=conv1.padding,
            bias=conv1.bias is not None
        )
        # Copy the surviving weights
        with torch.no_grad():
            new_conv1.weight.data = conv1.weight.data[mask1, :, :, :]
            if conv1.bias is not None:
                new_conv1.bias.data = conv1.bias.data[mask1]
        # Rebuild the second conv layer with the pruned input channels
        conv2 = block.conv2
        new_conv2 = nn.Conv2d(
            in_channels=sum(mask1),  # pruned channel count becomes the new input width
            out_channels=conv2.out_channels,
            kernel_size=conv2.kernel_size,
            stride=conv2.stride,
            padding=conv2.padding,
            bias=conv2.bias is not None
        )
        # Copy the weights, slicing along the input-channel dimension
        with torch.no_grad():
            new_conv2.weight.data = conv2.weight.data[:, mask1, :, :]
            if conv2.bias is not None:
                new_conv2.bias.data = conv2.bias.data
        # Swap the new layers into the block
        block.conv1 = new_conv1
        block.conv2 = new_conv2
        # Rebuild the BatchNorm layer that follows conv1
        if hasattr(block, 'bn1'):
            bn1 = block.bn1
            new_bn1 = nn.BatchNorm2d(sum(mask1))
            with torch.no_grad():
                new_bn1.weight.data = bn1.weight.data[mask1]
                new_bn1.bias.data = bn1.bias.data[mask1]
                new_bn1.running_mean = bn1.running_mean[mask1]
                new_bn1.running_var = bn1.running_var[mask1]
            block.bn1 = new_bn1
        # Report the updated channel counts
        print(f"After pruning, new_conv1 out_channels: {new_conv1.out_channels}")
        print(f"After pruning, new_conv2 in_channels: {new_conv2.in_channels}")
        return mask1
    return None
def prune_model(model, pruning_percent):
    # Collect all residual blocks
    blocks = []
    for name, module in model.named_modules():
        if isinstance(module, torchvision.models.resnet.BasicBlock):
            blocks.append((name, module))
    # Prune each residual block
    for name, block in blocks:
        print(f"Pruning {name}...")
        mask = prune_resnet_block(block, pruning_percent)
    return model
def fine_tune_model(model, trainloader, testloader, criterion, optimizer, scheduler, epochs):
best_acc = 0.0
for epoch in range(1, epochs + 1):
train_loss, train_acc = train(model, trainloader, criterion, optimizer, epoch)
test_loss, test_acc = test(model, testloader, criterion)
if test_acc > best_acc:
best_acc = test_acc
torch.save(model.state_dict(), 'best_model.pth')
scheduler.step()
print(f'Best test accuracy: {best_acc:.2f}%')
return best_acc
def print_model_shapes(model):
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
print(f"{name}: in_channels={module.in_channels}, out_channels={module.out_channels}")
elif isinstance(module, nn.BatchNorm2d):
print(f"{name}: num_features={module.num_features}")
if __name__ == "__main__":
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    # Data preprocessing
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
    # Load the datasets
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Build the model
model = make_resnet18_cifar10()
model = model.to(device)
    # Initial training (fine-tuning)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
print("Starting initial training (fine-tuning)...")
best_acc = fine_tune_model(model, trainloader, testloader, criterion, optimizer, scheduler, 20)
    # Load the best checkpoint
model.load_state_dict(torch.load('best_model.pth'))
    # Print the original model size
print("\nOriginal model size:")
print_model_size(model)
print("\n原始模型结构:")
summary(model, input_size=(1, 3, 32, 32))
    # Make a copy of the model for pruning
    pruned_model = copy.deepcopy(model)
    # Run pruning
    pruning_percent = 0.2  # uniform pruning ratio
    pruned_model = prune_model(pruned_model, pruning_percent)  # run the pruning pass
summary(pruned_model, input_size=(1, 3, 32, 32))
    # Call this once pruning is complete
print("\nPruned model shapes:")
print_model_shapes(pruned_model)
    # Print the pruned model size
print("\nPruned model size:")
print_model_size(pruned_model)
    # Define a new optimizer (a smaller learning rate may be needed)
optimizer_pruned = optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler_pruned = optim.lr_scheduler.CosineAnnealingLR(optimizer_pruned, T_max=100)
print("Starting fine-tuning after pruning...")
best_pruned_acc = fine_tune_model(pruned_model, trainloader, testloader, criterion, optimizer_pruned, scheduler_pruned, 15)
    # Compare the performance of the original and pruned models
print("\nResults Comparison:")
print(f"Original model accuracy: {best_acc:.2f}%")
print(f"Pruned model accuracy: {best_pruned_acc:.2f}%")
print(f"Accuracy drop: {best_acc - best_pruned_acc:.2f}%")