DINO 的核心是通过教师-学生模型(Teacher-Student)的自蒸馏框架,让两个不同视角(不同数据增强)的同一图像的表征尽可能一致,是一种自监督模型。
教师-学生架构(Teacher-Student)
学生网络(Student):通过梯度下降更新,输入是局部裁剪(Local Crops)(如小尺寸图像块)。
教师网络(Teacher):不更新梯度,其权重是学生网络的指数移动平均(EMA):
其中 λ 是动量系数(如 0.996),控制教师模型的平滑更新。
数据增强策略(Multi-Crop)
DINO 使用不同尺度的裁剪(类似 SwAV):
Global Views(全局视图):2 个大尺寸裁剪(覆盖大部分图像)。
Local Views(局部视图):多个小尺寸裁剪(聚焦局部细节)。
目标是让学生模型预测教师模型对同一图像不同视角的特征。
损失函数(Cross-Entropy Loss)
DINO 采用交叉熵损失,让学生模型的输出分布匹配教师模型的输出分布:
教师模型的输出经过
softmax
归一化(温度参数)。
学生模型的输出也经过
softmax
归一化(温度参数)。
通过调整温度参数,可以控制概率分布的平滑程度。
避免模型坍塌(Avoiding Collapse)
中心化(Centering):对教师模型的输出进行零均值化,防止所有样本预测相同类别:
其中 c 是滑动平均的均值向量。
锐化(Sharpening):使用较低的
softmax
温度(如 0.1),让教师模型的预测更置信,引导学生模型学习更明确的特征。
在DINO框架中,教师模型负责处理全局视图(如完整图像),而学生模型则处理局部视图(如图像裁剪块)。这种设计通过温度调控实现了有效的知识蒸馏:教师模型使用较低的softmax温度(如0.04),使其输出分布更加尖锐和确定,从而为特征学习提供高置信度的指导目标;与此同时,学生模型采用较高的温度(如0.1),使其能够以更平滑的概率分布来捕捉局部视图与全局特征之间的潜在关联。通过这种机制,学生模型能够从局部细节中推断出全局的语义特征,而教师模型提供的精确目标则确保了特征表示的一致性。此外,中心化(centering)和锐化(sharpening)技术的引入进一步防止了模型坍塌,使得这种跨视图的特征匹配能够稳定地收敛到有意义的解。这种基于温度调控的自蒸馏策略,本质上构建了一个动态的师生互动系统:教师不断提供经过"深思熟虑"(低温精确)的特征表示,而学生则通过"广泛探索"(高温平滑)来学习如何从局部信息重建全局理解
代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision.transforms import Compose,RandomCrop,RandomHorizontalFlip,ToTensor,Normalize
from torch.utils.data import DataLoader
from torchvison.datasets import CIFAR10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class VIT(nn.Module):
def __init__(self,output_dim):
super(ViT,self).__init__()
self.vit = timm.create_model('vit_small_patch16_224',pretrained=False,num_classes=output_dim )
def forward(self,x):
return self.vit(x)
def get_dataloader(batch_size):
transform = Compose(
[
Resize(224)
RandomCrop(224,padding=4)
RandomHorizontalFlip()
ToTensor(),
Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
]
)
dataset = CIFAR10(root = './cifar10',train=True,transform=transform,download =True)
return DataLoader(dataset,batch_size=batch_size,shuffle=True)
output_dim =128
gs = ViT(output_dim).to(device)
gt = ViT(output_dim).to(device)
gt.load_state_dict(gs.state_dict())
C =torch.zeros(output_dim,device = device)
def H(t,s,C,tps,tpt):
t = t.detach()
s = F.softmax(s/tps,dim=1)
t = F.softmax((t-C)/tpt,dim=1)
return -(t * torch.log(s)).sum(dim=1).mean()
def augement(x):
return x + 0.1*torch.randn_like(x)
batch_size = 64
#平滑softmax用的参数
tps = 0.1
tpt = 0.07
#l是lamda
l = 0.6
m = 0.5
optimizer = optim.SGD(gs.parameters(),lr =0.03,momentum=0.9,weight_decay=5e-4)
#training
loader = get_dataloader(batch_size)
for x, _ in loader: # We don't need labels for self-supervised learning
x = x.to(device) # (B,3,224,224)
x1, x2 = augment(x), augment(x) # random views
s1, s2 = gs(x1), gs(x2) # student output, (B,128)
t1, t2 = gt(x1), gt(x2) # teacher output, (B,128)
loss = H(t1, s2, C, tps, tpt)/2 + H(t2, s1, C, tps, tpt)/2 # divide by 2 for combined loss
loss.backward() # back-propagate
optimizer.step() # SGD update for student
optimizer.zero_grad() # Clear gradients
# Teacher and center updates
with torch.no_grad():
for teacher_param, student_param in zip(gt.parameters(), gs.parameters()):
teacher_param.data = l * teacher_param.data + (1 - l) * student_param.data
C = m * C + (1 - m) * torch.cat([t1, t2]).mean(dim=0)
print("Training complete!")