知识蒸馏的概念
可以参照Hinton等人的论文“Distilling the Knowledge in a Neural Network”(NIPS 2014 Deep Learning Workshop,arXiv:1503.02531)了解知识蒸馏的概念。
知识蒸馏的狭义概念就是从复杂模型中迁移知识来提升简单模型的性能。复杂模型称之为教师模型,简单模型称之为学生模型。最近,笔者重温了知识蒸馏的概念,并在CIFAR100数据集上对知识蒸馏进行了验证和实验。
logits,硬目标,软目标的概念:logits指的是网络最后一层(softmax之前)输出的未归一化分数,而不是概率;硬目标指的是真值标签的one-hot编码;软目标指的是对logits进行softmax之后得到的概率分布。
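加入温度系数的软目标:为了让softmax之后的概率分布更加平滑(软化),Hinton提出在softmax中引入温度参数T,即先将logits除以T再进行softmax:
$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
其中$z_i$为第$i$个类别的logit,T为温度;T=1时即为普通的softmax,T越大,概率分布越平缓。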
数据集:CIFAR100,是一个经典的图像分类数据集,共有100个图像类别。
数据集直接采用Pytorch定义的官方数据集进行加载
import torchvision
from torch.utils.data import DataLoader  # needed by the two loaders below
from torchvision import transforms

# Per-channel mean / std handed to Normalize in both pipelines.
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Training images are randomly cropped, flipped and rotated before being
# converted to tensors and normalised.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

# Test images get no augmentation: tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

# Both splits are stored under ./dataset/ and downloaded if missing.
train_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/",
    train=True,
    transform=transform_train,
    download=True,
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/",
    train=False,
    transform=transform_test,
    download=True,
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)
分类模型:采用ResNet50作为教师模型,VGG16作为学生模型。
VGG16网络定义代码
"""vgg in pytorch
[1] Karen Simonyan, Andrew Zisserman
Very Deep Convolutional Networks for Large-Scale Image Recognition.
https://arxiv.org/abs/1409.1556v6
"""
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
# Layer specifications for the VGG variants, one stage per row.
# An int is the output channel count of a 3x3 convolution; 'M' is a 2x2
# max-pooling step (see make_layers).  Counting the convolutions plus the
# three linear layers of the classifier gives 11 / 13 / 16 / 19 weight
# layers for 'A' / 'B' / 'D' / 'E' respectively.
cfg = {
    'A': [
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'B': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    'D': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    'E': [
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
}
class VGG(nn.Module):
    """VGG-style classifier: a feature extractor followed by a 3-layer head.

    Args:
        features: module applied to the input first.  Its output, flattened
            per sample, must hold 512 values because the first linear layer
            of the head takes 512 inputs.
        num_class: number of outputs of the last linear layer (default 100).
    """

    def __init__(self, features, num_class=100):
        super().__init__()
        self.features = features
        # Linear -> ReLU -> Dropout, twice, then the final projection.
        head_layers = [
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_class),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the classifier output for the batch ``x``."""
        feature_maps = self.features(x)
        batch_size = feature_maps.size()[0]
        # Collapse everything but the batch dimension before the linear head.
        flattened = feature_maps.view(batch_size, -1)
        return self.classifier(flattened)
def make_layers(cfg, batch_norm=False):
    """Build the convolutional stack described by ``cfg``.

    Args:
        cfg: iterable of layer specs.  ``'M'`` adds a 2x2 max-pool with
            stride 2; any other entry is used as the output channel count of
            a 3x3 convolution (padding 1) followed by ReLU.
        batch_norm: when true, a BatchNorm2d is inserted between each
            convolution and its ReLU.

    Returns:
        An ``nn.Sequential`` holding the layers in order.  The first
        convolution takes 3 input channels.
    """
    modules = []
    in_planes = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            modules.append(nn.Conv2d(in_planes, spec, kernel_size=3, padding=1))
            if batch_norm:
                modules.append(nn.BatchNorm2d(spec))
            modules.append(nn.ReLU(inplace=True))
            # The next convolution consumes what this one produced.
            in_planes = spec
    return nn.Sequential(*modules)
def vgg16_bn():
    """Return a VGG built from configuration 'D' with batch normalisation."""
    feature_extractor = make_layers(cfg['D'], batch_norm=True)
    return VGG(feature_extractor)
ResNet50网络定义代码
"""resnet in pytorch
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
Deep Residual Learning for Image Recognition
https://arxiv.org/abs/1512.03385v1
"""
import torch
import torch.nn as nn
class BasicBlock(nn.Module):
    """Basic residual block, used for resnet 18 and resnet 34.

    ``expansion`` is the factor between ``out_channels`` and the number of
    channels the block actually emits.  BasicBlock and BottleNeck differ in
    this factor, and the class attribute is how they are told apart.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        emitted_channels = out_channels * BasicBlock.expansion

        # Residual branch: 3x3 conv (carries the stride) -> BN -> ReLU ->
        # 3x3 conv -> BN.
        residual_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, emitted_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(emitted_channels),
        ]
        self.residual_function = nn.Sequential(*residual_layers)

        # Shortcut branch: an empty Sequential (identity) when input and
        # output dimensions agree; otherwise a strided 1x1 convolution + BN
        # so the two branches can be added.
        dimensions_match = stride == 1 and in_channels == emitted_channels
        if dimensions_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, emitted_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(emitted_channels),
            )

    def forward(self, x):
        """Return ReLU(residual(x) + shortcut(x))."""
        combined = self.residual_function(x) + self.shortcut(x)
        return nn.ReLU(inplace=True)(combined)
class BottleNeck(nn.Module):
    """Residual block for resnet over 50 layers.

    The block emits ``out_channels * expansion`` channels, with
    ``expansion`` fixed at 4.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        emitted_channels = out_channels * BottleNeck.expansion

        # Residual branch: 1x1 conv -> 3x3 conv (carries the stride) ->
        # 1x1 conv up to the emitted width, each followed by BN; ReLU after
        # the first two only.
        residual_layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, emitted_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(emitted_channels),
        ]
        self.residual_function = nn.Sequential(*residual_layers)

        # Shortcut branch: empty Sequential (identity) unless the stride or
        # the channel count changes, in which case a 1x1 conv + BN projects
        # the input to the residual branch's output dimensions.
        dimensions_match = stride == 1 and in_channels == emitted_channels
        if dimensions_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, emitted_channels, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(emitted_channels),
            )

    def forward(self, x):
        """Return ReLU(residual(x) + shortcut(x))."""
        combined = self.residual_function(x) + self.shortcut(x)
        return nn.ReLU(inplace=True)(combined)
class ResNet(nn.Module):
    """ResNet: a 3x3 convolutional stem, four residual stages, global
    average pooling and a linear classifier.

    Args:
        block: block class used for every stage (basic block or bottle
            neck block).  It is called as ``block(in_channels,
            out_channels, stride)`` and must expose an ``expansion``
            attribute.
        num_block: sequence of four ints, the number of blocks per stage.
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, block, num_block, num_classes=100):
        super().__init__()

        # Channel count fed to the next block; updated by _make_layer.
        self.in_channels = 64

        stem_layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        self.conv1 = nn.Sequential(*stem_layers)

        # We use a different input size than the original paper, so
        # conv2_x's stride is 1; the remaining stages start with stride 2.
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one resnet stage.

        "Layer" here means a stage, not a single network layer such as a
        convolution: one stage may contain more than one residual block.

        Args:
            block: block type, basic block or bottle neck block.
            out_channels: output depth channel number of this stage.
            num_blocks: how many blocks the stage holds.
            stride: stride of the first block of the stage.

        Returns:
            The stage as an ``nn.Sequential``.
        """
        stage = []
        # Only the first block uses the requested stride (1 or 2); every
        # following block uses stride 1.
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the classifier output for the batch ``x``."""
        out = self.conv1(x)
        for stage in (self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x):
            out = stage(out)
        pooled = self.avg_pool(out)
        # Flatten to (batch, features) for the linear classifier.
        flattened = pooled.view(pooled.size(0), -1)
        return self.fc(flattened)
def resnet50():
    """Return a ResNet 50 object: BottleNeck blocks, [3, 4, 6, 3] per stage."""
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(BottleNeck, blocks_per_stage)
先单独训练教师模型和学生模型,分别统计教师模型学生模型的精度
损失函数 nn.CrossEntropyLoss()
优化器 torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
学习率曲线 torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
epochs = 200
教师模型训练代码
import torch
from torch import nn
from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from my_resnet import resnet50
def TeacherModel():
    """Build the teacher network: a freshly constructed ResNet 50 object."""
    return resnet50()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel mean / std handed to Normalize in both pipelines.
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Training images are randomly cropped, flipped and rotated before being
# converted to tensors and normalised.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

# Test images get no augmentation: tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

train_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/",
    train=True,
    transform=transform_train,
    download=True,
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/",
    train=False,
    transform=transform_test,
    download=True,
)

train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)

if __name__ == "__main__":
    import os

    # Train the teacher model from scratch.
    model = TeacherModel().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    # Learning rate decay: multiplied by 0.2 at epochs 60, 120 and 160.
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

    # torch.save does not create missing directories, so make sure the
    # checkpoint folder exists before the first save.
    os.makedirs('./weights/teacher_cifar100', exist_ok=True)

    epochs = 200
    best_acc = 0.0
    for epoch in range(epochs):
        model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            prediction = model(data)
            loss = criterion(prediction, targets)
            loss.backward()
            optimizer.step()

        # Advance the schedule once per epoch, after the optimizer updates.
        # This yields the same per-epoch learning rates as the previous
        # `step(epoch)` call at the start of the epoch, without relying on
        # the deprecated explicit-epoch argument.
        train_scheduler.step()

        # Evaluate on the test split.
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                prediction = model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
        # Deliberately kept as a tensor division followed by .item(): the
        # resulting float is embedded verbatim in the checkpoint file name.
        acc = (num_correct/num_samples).item()

        # Keep a checkpoint every time the test accuracy improves.
        if acc > best_acc:
            torch.save(model.state_dict(), './weights/teacher_cifar100/teacher_{}.pth'.format(acc))
            best_acc = acc
        print("Epoch {}: 当前模型最佳精度为:{:.4f}".format(epoch, best_acc))
"""
教师模型
Epoch 199: 当前模型最佳精度为:0.7840
"""
教师模型的分类精度为78.40%
学生模型的训练代码
import torch
from torch import nn
from tqdm import tqdm
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from my_vgg import vgg16_bn
def StudentModel():
    """Build the student network: whatever ``vgg16_bn()`` constructs."""
    return vgg16_bn()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-channel mean / std handed to Normalize in both pipelines.
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Training images are randomly cropped, flipped and rotated before being
# converted to tensors and normalised.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

# Test images get no augmentation: tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

train_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/",
    train=True,
    transform=transform_train,
    download=True,
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/",
    train=False,
    transform=transform_test,
    download=True,
)

train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)

if __name__ == "__main__":
    import os

    # Train the student model from scratch (no distillation).
    model = StudentModel().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    # Learning rate decay: multiplied by 0.2 at epochs 60, 120 and 160.
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

    # torch.save does not create missing directories, so make sure the
    # checkpoint folder exists before the first save.
    os.makedirs('./weights/student_cifar100_vgg16', exist_ok=True)

    epochs = 200
    best_acc = 0.0
    for epoch in range(epochs):
        model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            prediction = model(data)
            loss = criterion(prediction, targets)
            loss.backward()
            optimizer.step()

        # Advance the schedule once per epoch, after the optimizer updates.
        # This yields the same per-epoch learning rates as the previous
        # `step(epoch)` call at the start of the epoch, without relying on
        # the deprecated explicit-epoch argument.
        train_scheduler.step()

        # Evaluate on the test split.
        model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                prediction = model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
        # Kept as a tensor division followed by .item() so the float that
        # ends up in the checkpoint file name is unchanged.
        acc = (num_correct/num_samples).item()

        # Keep a checkpoint every time the test accuracy improves.
        if acc > best_acc:
            torch.save(model.state_dict(), './weights/student_cifar100_vgg16/student_{}.pth'.format(acc))
            best_acc = acc
        print("Epoch {}: 当前模型最佳精度为:{:.4f}".format(epoch, best_acc))
"""
学生模型 VGG16
Epoch 199: 当前模型最佳精度为:0.7121
"""
学生模型的分类精度(在测试集上)为71.21%
教师-学生模型蒸馏训练,学生损失为CE交叉熵损失,蒸馏损失为KL散度损失
重点一:蒸馏学生损失loss=(1-alpha) * T * T * soft_loss + alpha * hard_loss,其中alpha为权重参数,T为Temperature温度参数,用于软目标化
具体可参见 bilibili视频
重点二:蒸馏损失的计算方式,student_predictions需要除以温度参数后进行F.log_softmax得到对数概率,teacher_predictions需要除以温度参数后进行F.softmax变成软目标
distillation_loss = soft_loss(F.log_softmax(student_predictions / Temp, dim=1), F.softmax(teacher_predictions / Temp, dim=1))
重点三:教师模型需要eval(), 得到教师模型输出需要 with torch.no_grad()和.detach()
with torch.no_grad():
teacher_predictions = teacher_model(data)
teacher_predictions = teacher_predictions.detach()
重点四:损失权重参数alpha和温度系数T的设定,笔者参照bilibili视频的设定,设置alpha为0.3,温度系数T为4
蒸馏训练代码
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
from teacher_cifar100 import TeacherModel
from vgg_student_cifar100 import StudentModel
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

# Per-channel mean / std handed to Normalize in both pipelines.
CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Training images are randomly cropped, flipped and rotated before being
# converted to tensors and normalised.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

# Test images get no augmentation: tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
])

# Load the CIFAR-100 datasets.
train_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/",
    train=True,
    transform=transform_train,
    download=True,
)
test_dataset = torchvision.datasets.cifar.CIFAR100(
    root="./dataset/",
    train=False,
    transform=transform_test,
    download=True,
)

train_loader = DataLoader(dataset=train_dataset, batch_size=128, num_workers=4, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, num_workers=4, shuffle=False)

if __name__ == "__main__":
    import os

    # Distillation training: a frozen, pre-trained teacher guides the student.
    # Loss recipe (log_softmax / softmax with temperature, T*T scaling)
    # follows https://www.bilibili.com/video/BV1Go4y1u72L/

    # The teacher stays in eval mode for the whole run.  map_location lets a
    # checkpoint saved on a GPU be loaded on a CPU-only machine as well.
    teacher_model = TeacherModel().to(device).eval()
    teacher_model.load_state_dict(
        torch.load("./weights/teacher_cifar100/teacher_0.7839999794960022.pth", map_location=device)
    )
    student_model = StudentModel().to(device)

    Temp = 4      # softmax temperature applied to both sets of logits
    alpha = 0.3   # weight of the hard-label loss; the soft loss gets 1 - alpha
    hard_loss = nn.CrossEntropyLoss()
    soft_loss = nn.KLDivLoss(reduction='batchmean')
    optimizer = torch.optim.SGD(student_model.parameters(), lr=0.02, momentum=0.9, weight_decay=5e-4)
    # Learning rate decay: multiplied by 0.2 at epochs 60, 120 and 160.
    train_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

    # torch.save does not create missing directories, so make sure the
    # checkpoint folder exists before the first save.
    os.makedirs('./weights/knowledge_distillation_cifar100_vgg16', exist_ok=True)

    epochs = 200
    best_acc = 0.0
    for epoch in range(epochs):
        student_model.train()
        for data, targets in tqdm(train_loader):
            data = data.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()

            # Teacher prediction: no gradients are tracked for the teacher.
            with torch.no_grad():
                teacher_predictions = teacher_model(data)
            teacher_predictions = teacher_predictions.detach()

            student_predictions = student_model(data)
            # Hard-label loss: cross entropy against the ground-truth labels.
            student_loss = hard_loss(student_predictions, targets)
            # Soft-label loss: KLDivLoss takes log-probabilities as input
            # and probabilities as target, both softened by Temp.
            distillation_loss = soft_loss(
                F.log_softmax(student_predictions / Temp, dim=1),
                F.softmax(teacher_predictions / Temp, dim=1)
            )
            # The soft term is scaled by Temp * Temp.
            loss = (1 - alpha) * Temp * Temp * distillation_loss + alpha * student_loss
            loss.backward()
            optimizer.step()

        # Advance the schedule once per epoch, after the optimizer updates.
        # This yields the same per-epoch learning rates as the previous
        # `step(epoch)` call at the start of the epoch, without relying on
        # the deprecated explicit-epoch argument.
        train_scheduler.step()

        # Evaluate the student on the test split.
        student_model.eval()
        num_correct = 0
        num_samples = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)
                prediction = student_model(x)
                prediction = prediction.max(1).indices
                num_correct += (prediction == y).sum()
                num_samples += prediction.size(0)
        # Kept as a tensor division followed by .item() so the float that
        # ends up in the checkpoint file name is unchanged.
        acc = (num_correct/num_samples).item()

        # Keep a checkpoint every time the test accuracy improves.
        if acc > best_acc:
            torch.save(student_model.state_dict(), './weights/knowledge_distillation_cifar100_vgg16/student_{}.pth'.format(acc))
            best_acc = acc
        print("Epoch {}: 当前模型最佳精度为:{:.4f}".format(epoch, best_acc))
"""
蒸馏学生模型 ResNet50 --> VGG16
ResNet50 当前模型最佳精度为:0.7840
VGG16 当前模型最佳精度为:0.7121
Temp = 4 alpha = 0.3 Acc Epoch 199: 当前模型最佳精度为:0.7388
"""
知识蒸馏实验对比结果
模型 | 网络结构 | 分类精度 |
---|---|---|
学生模型 | VGG16 | 71.21% |
教师模型 | ResNet50 | 78.40% |
蒸馏学生模型 | VGG16 | 73.88% |