在对比学习(Contrastive Learning)中,NCE(Noise-Contrastive Estimation)和InfoNCE是两种常见的目标函数,它们都用于通过区分正样本和负样本来学习高质量的表示。
1. NCE(Noise-Contrastive Estimation)
- 定义:NCE最初是为无监督学习中的概率模型(如语言模型)设计的,目的是通过对比正样本和噪声样本(负样本)来估计概率分布。它将任务转化为一个二分类问题:区分真实数据(正样本)和噪声分布生成的样本(负样本)。
- 目标:最大化正样本的对数似然,同时最小化负样本的对数似然。形式上,NCE的目标函数可以写为:
L N C E = E p data ( x ) [ log σ ( f ( x ) ) ] + k ⋅ E p noise ( x ) [ log ( 1 − σ ( f ( x ) ) ) ] L_{NCE} = \mathbb{E}_{p_{\text{data}}(x)}[\log \sigma(f(x))] + k \cdot \mathbb{E}_{p_{\text{noise}}(x)}[\log (1 - \sigma(f(x)))] LNCE=Epdata(x)[logσ(f(x))]+k⋅Epnoise(x)[log(1−σ(f(x)))]
其中:- p data ( x ) p_{\text{data}}(x) pdata(x) 是真实数据分布;
- p noise ( x ) p_{\text{noise}}(x) pnoise(x) 是噪声分布;
- f ( x ) f(x) f(x) 是模型的评分函数(如神经网络输出);
- σ \sigma σ 是sigmoid函数;
- k k k 是负样本的数量。
- 核心思想:通过对比正样本和负样本,逼近真实数据分布的似然估计。负样本通常从一个预定义的噪声分布(如均匀分布或高斯分布)中采样。
- 特点:
- 更偏向于概率建模,适用于估计数据分布。
- 负样本的生成依赖于噪声分布的选择,质量和多样性可能受限。
- 计算复杂度与负样本数量成正比。
2. InfoNCE
- 定义:InfoNCE是NCE的一种变体,全称是“Information Noise-Contrastive Estimation”,广泛用于现代对比学习框架(如SimCLR、MoCo等)。它基于互信息(Mutual Information)的最大化,通过对比正样本和一组负样本学习表示。
- 目标:InfoNCE的目标是最大化正样本对的相似度,同时最小化与负样本的相似度。其形式通常为:
L I n f o N C E = − E [ log exp ( s ( x , x + ) / τ ) exp ( s ( x , x + ) / τ ) + ∑ i = 1 N exp ( s ( x , x i − ) / τ ) ] L_{InfoNCE} = -\mathbb{E} \left[ \log \frac{\exp(s(x, x^+)/\tau)}{\exp(s(x, x^+)/\tau) + \sum_{i=1}^{N} \exp(s(x, x_i^-)/\tau)} \right] LInfoNCE=−E[logexp(s(x,x+)/τ)+∑i=1Nexp(s(x,xi−)/τ)exp(s(x,x+)/τ)]
其中:- x x x 是锚点样本, x + x^+ x+ 是正样本, x i − x_i^- xi− 是负样本;
- s ( ⋅ , ⋅ ) s(\cdot, \cdot) s(⋅,⋅) 是相似度函数(如余弦相似度或点积);
- τ \tau τ 是温度参数,用于控制分布的平滑性;
- N N N 是负样本数量。
- 核心思想:通过softmax形式的分类任务,最大化锚点与正样本之间的互信息下界。负样本通常来自同一批次的数据(而不是预定义的噪声分布)。
- 特点:
- 更适合表示学习(representation learning),而非直接建模概率分布。
- 负样本通常从数据本身采样(如数据增强或批次中的其他样本),更贴近任务分布。
- 引入温度参数 τ \tau τ,可以调节模型对相似度的敏感度。
主要区别
方面 | NCE | InfoNCE |
---|---|---|
设计初衷 | 概率分布估计(如语言模型) | 表示学习(如图像、视频表征) |
负样本来源 | 预定义噪声分布(如均匀或高斯) | 数据本身(如批次内其他样本) |
目标 | 逼近似然估计 | 最大化互信息下界 |
数学形式 | 二分类形式(sigmoid) | 多分类形式(softmax) |
温度参数 | 无 | 有( τ \tau τ,控制分布平滑性) |
应用场景 | 传统无监督学习(如word2vec) | 现代对比学习(如SimCLR、MoCo) |
负样本数量 | 固定或较少 | 通常较多(如批次大小决定) |
联系与演化
- 联系:InfoNCE可以看作是NCE的扩展和改进。两者都基于对比的思想,即通过区分正样本和负样本优化模型。
- 演化:NCE更早提出,适用于概率建模任务;InfoNCE在深度学习时代被改进,结合了互信息的理论基础,更适合表示学习任务,尤其是在大规模数据和神经网络中。
- 如果任务是估计概率分布(如生成模型),NCE可能更合适。
- 如果目标是学习数据表示(如自监督学习中的特征提取),InfoNCE是更好的选择,因为它更灵活且与现代深度学习框架兼容性更高。
案例
“NCE 是 word2vec 中用于高效训练词嵌入的关键技术,它通过将多分类问题转化为二分类问题来避免 softmax 计算中的归一化项,大幅提高训练效率。”
假设你在找朋友“猫”常去的咖啡店:
- 原始方法(softmax):你得跑遍城里 10 万家咖啡店,算每家店“猫”去的可能性,最后挑出最可能的。
- NCE 方法:你只去“猫”常去的 1 家店(正样本),再随便挑 5 家“猫”从不去的店(负样本),问:“猫在这儿吗?”判断完就行了。
1. 什么是 word2vec 和词嵌入?
- 词嵌入:简单来说,就是把单词变成一串数字(向量),让计算机能理解。比如“猫”可能是 [0.1, 0.5, -0.2],“狗”可能是 [0.2, 0.4, -0.1],相似的词向量会更接近。
- word2vec:一个工具,用来从大量文本中学习这些词向量。它的工作方式是:给一个词(比如“猫”),预测它周围的词(比如“喵”或“宠物”)。
2. 为什么训练 word2vec 会有问题?
想象你在教一个模型预测“猫”旁边的词可能是“喵”。模型需要:
- 看整个词典(比如 10 万个词)。
- 给每个词打分,算出“喵”的概率最高。
- 用一个叫 softmax 的公式,把所有词的得分变成概率。
问题来了:如果词典有 10 万个词,每次预测都要算 10 万次得分,再把它们加起来(这就是“归一化项”),太慢了!就像你每次点餐都要把菜单上的 10 万道菜全看一遍,太费时间。
3. NCE 是什么?它怎么解决问题?
NCE(Noise-Contrastive Estimation,噪声对比估计)是 word2vec 的一个“加速器”。它不让你看整个菜单,而是用一个聪明的方法:
- 原来的任务(多分类):从 10 万个词里挑出“喵”作为“猫”的邻居。
- NCE 的新任务(二分类):只问一个简单问题:“这个词(比如‘喵’)是‘猫’的邻居吗?是/不是。”然后再随机挑几个“假邻居”(比如“桌子”“飞机”),问:“这些是‘猫’的邻居吗?是/不是。”
类比:就像考试从 10 万道选择题变成几道判断题。你不用算所有词的概率,只需要判断几个词的对错,简单多了。
4. “避免 softmax 计算中的归一化项”是什么意思?
- softmax:本来要算所有 10 万个词的得分总和(归一化项),才能知道“喵”的概率是多少。
- NCE:不算总和了!它只看“喵”和几个假词(比如“桌子”“飞机”),直接比较它们和“猫”的匹配度。这样就不用把 10 万个词加起来,省了很多计算。
举个例子:
- 原始方法:算“喵”“狗”“桌子”“飞机”……10 万个词的得分,再加起来,得出“喵”的概率。
- NCE 方法:只算“喵”和 5 个随机假词(比如“桌子”“飞机”),判断“喵”是真的邻居,其他是假的。
5. “大幅提高训练效率”
- 直观理解:原来算 10 万次,现在只算 6 次(1 个真词 + 5 个假词),速度快了几万倍!
- 结果:模型训练从几天变成几小时,甚至更快,能处理更大的词典和文本。
人工生成语料库代码解释
word_probs = 1.0 / torch.arange(1, vocab_size + 1, dtype=torch.float) ** 0.75
word_probs /= word_probs.sum()
corpus = torch.multinomial(word_probs, corpus_size, replacement=True)
1. word_probs = 1.0 / torch.arange(1, vocab_size + 1, dtype=torch.float) ** 0.75
- 做什么:这一行创建了一个词的“概率分布”,但不是均匀分布,而是根据词的排名(频率)调整的。
- 通俗理解:想象你有一个词典(比如 5 个词),每个词按频率排名:第 1 名最常见,第 5 名最不常见。这里用一个公式(倒数幂次)给每个词分配一个初始“重要性”分数,排名靠前的词分数更高,但差距不会太大。
- 细节:
torch.arange(1, vocab_size + 1, dtype=torch.float)
:生成一个从 1 到vocab_size
的序列,比如词典有 5 个词,就是[1, 2, 3, 4, 5]
。- ∗ ∗ 0.75 ** 0.75 ∗∗0.75:对每个数字取 0.75 次幂,比如 1 0.75 = 1 1^{0.75} = 1 10.75=1, 2 0.75 ≈ 1.68 2^{0.75} \approx 1.68 20.75≈1.68,这会让数字增长变慢,拉近高低排名的差距。
- 1.0 / . . . 1.0 / ... 1.0/...:取倒数,比如 $ 1/1 = 1 , , , 1/1.68 \approx 0.595 $,排名越靠后,分数越低。
2. word_probs /= word_probs.sum()
- 做什么:把上一行的分数变成真正的概率,确保它们加起来等于 1。
- 通俗理解:就像你把一堆筹码分给不同的人,但总筹码数要固定为 1。上一行算出的分数只是相对大小,这里把它们“标准化”,变成概率。
- 细节:用总和除以每个分数,保证是个合法的概率分布。
3. corpus = torch.multinomial(word_probs, corpus_size, replacement=True)
- 做什么:根据刚刚算出的概率,随机生成一个语料库(corpus),里面是词的索引。
- 通俗理解:想象一个抽奖转盘,每个词占一块区域,概率高的词区域大,被抽中的机会多。这里根据
word_probs
的概率,抽corpus_size
次,生成一串词的编号。 - 细节:
torch.multinomial
:一个抽样函数,按概率抽取。replacement=True
:表示“有放回”抽样,同一个词可以被重复抽中。- 输出是一个长度为
corpus_size
的张量,里面是词的索引(0 到vocab_size-1
)。
具体例子
假设我们设置:
vocab_size = 5
(词典里有 5 个词:[“猫”, “狗”, “鱼”, “鸟”, “鼠”])corpus_size = 10
(想生成一个 10 个词的语料库)
步骤 1:计算初始概率
word_probs = 1.0 / torch.arange(1, 6, dtype=torch.float) ** 0.75
torch.arange(1, 6)
生成[1, 2, 3, 4, 5]
。- 每个数取 0.75 次幂:
- 1 0.75 = 1 1^{0.75} = 1 10.75=1
- 2 0.75 ≈ 1.6818 2^{0.75} \approx 1.6818 20.75≈1.6818
- 3 0.75 ≈ 2.2795 3^{0.75} \approx 2.2795 30.75≈2.2795
- 4 0.75 ≈ 2.8284 4^{0.75} \approx 2.8284 40.75≈2.8284
- 5 0.75 ≈ 3.3437 5^{0.75} \approx 3.3437 50.75≈3.3437
- 取倒数:
- 1 / 1 = 1 1/1 = 1 1/1=1
- 1 / 1.6818 ≈ 0.5946 1/1.6818 \approx 0.5946 1/1.6818≈0.5946
- 1 / 2.2795 ≈ 0.4386 1/2.2795 \approx 0.4386 1/2.2795≈0.4386
- 1 / 2.8284 ≈ 0.3536 1/2.8284 \approx 0.3536 1/2.8284≈0.3536
- 1 / 3.3437 ≈ 0.2991 1/3.3437 \approx 0.2991 1/3.3437≈0.2991
- 结果:
word_probs = [1, 0.5946, 0.4386, 0.3536, 0.2991]
。
步骤 2:归一化成概率
word_probs /= word_probs.sum()
- 先算总和: 1 + 0.5946 + 0.4386 + 0.3536 + 0.2991 ≈ 2.6859 1 + 0.5946 + 0.4386 + 0.3536 + 0.2991 \approx 2.6859 1+0.5946+0.4386+0.3536+0.2991≈2.6859。
- 归一化:
- 1 / 2.6859 ≈ 0.3724 1 / 2.6859 \approx 0.3724 1/2.6859≈0.3724
- 0.5946 / 2.6859 ≈ 0.2214 0.5946 / 2.6859 \approx 0.2214 0.5946/2.6859≈0.2214
- 0.4386 / 2.6859 ≈ 0.1633 0.4386 / 2.6859 \approx 0.1633 0.4386/2.6859≈0.1633
- 0.3536 / 2.6859 ≈ 0.1317 0.3536 / 2.6859 \approx 0.1317 0.3536/2.6859≈0.1317
- 0.2991 / 2.6859 ≈ 0.1113 0.2991 / 2.6859 \approx 0.1113 0.2991/2.6859≈0.1113
- 结果:
word_probs = [0.3724, 0.2214, 0.1633, 0.1317, 0.1113]
。 - 检查:加起来是 1,符合概率要求。
步骤 3:生成语料库
corpus = torch.multinomial(word_probs, 10, replacement=True)
- 根据概率
[0.3724, 0.2214, 0.1633, 0.1317, 0.1113]
抽 10 次。 - 假设抽到的结果是:
[0, 2, 1, 0, 4, 3, 0, 1, 2, 0]
。 - 对应词:
["猫", "鱼", "狗", "猫", "鼠", "鸟", "猫", "狗", "鱼", "猫"]
。
为什么这样做?
这种方法模仿了自然语言中词频的分布(常见词多,罕见词少),常用于 word2vec 的负采样或语料生成。 0.75 0.75 0.75 次幂是个经验值,能平衡常见词和稀有词的出现频率。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
class ImprovedNCELoss(nn.Module):
"""
噪声对比估计(Noise Contrastive Estimation, NCE)损失函数实现
NCE是word2vec中用于高效训练词嵌入的关键技术,它通过将多分类问题转化为
二分类问题来避免softmax计算中的归一化项,大幅提高训练效率。
参数:
vocab_size (int): 词汇表大小
embedding_dim (int): 词嵌入维度
num_neg_samples (int): 每个正样本对应的负样本数量
noise_distribution (torch.Tensor): 噪声分布,默认为均匀分布
"""
def __init__(self, vocab_size, embedding_dim, num_neg_samples, noise_distribution=None):
super(ImprovedNCELoss, self).__init__()
# 初始化词嵌入矩阵
self.center_embeddings = nn.Embedding(vocab_size, embedding_dim)
self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)
# Xavier初始化以改善训练收敛性
nn.init.xavier_uniform_(self.center_embeddings.weight)
nn.init.xavier_uniform_(self.context_embeddings.weight)
self.num_neg_samples = num_neg_samples
# 如果没有提供噪声分布,则使用均匀分布
if noise_distribution is None:
self.register_buffer('noise_distribution', torch.ones(vocab_size))
else:
# 确保分布和为1
noise_distribution = noise_distribution / noise_distribution.sum()
self.register_buffer('noise_distribution', noise_distribution)
# 数值稳定性参数
self.eps = 1e-10
def forward(self, center_words, context_words):
"""
计算NCE损失
参数:
center_words (torch.Tensor): 形状为[batch_size]的中心词索引
context_words (torch.Tensor): 形状为[batch_size]的上下文词索引
返回:
loss (torch.Tensor): NCE损失值
pos_loss (torch.Tensor): 正样本损失
neg_loss (torch.Tensor): 负样本损失
accuracy (torch.Tensor): 二分类准确率
"""
batch_size = center_words.size(0)
device = center_words.device # 获取输入张量的设备
# 获取中心词和上下文词的嵌入表示
# [batch_size, embedding_dim]
center_embeds = self.center_embeddings(center_words)
context_embeds = self.context_embeddings(context_words)
# 计算正样本得分: 点积后应用Sigmoid
# [batch_size]
pos_scores = torch.sum(center_embeds * context_embeds, dim=1)
pos_probs = torch.sigmoid(pos_scores)
# 从噪声分布中采样负样本
# [batch_size, num_neg_samples]
neg_samples = torch.multinomial(
self.noise_distribution,
batch_size * self.num_neg_samples,
replacement=True
).view(batch_size, self.num_neg_samples).to(device) # 确保在正确的设备上
# 确保负样本不包含正样本
for i in range(batch_size):
mask = (neg_samples[i] == context_words[i])
if mask.any():
# 如果存在冲突,随机替换 - 确保替换索引在正确的设备上
replacement_indices = torch.randint(
0,
len(self.noise_distribution),
(mask.sum(),),
device=device # 确保在与其他张量相同的设备上
)
neg_samples[i, mask] = replacement_indices
# 获取负样本的嵌入表示
# [batch_size, num_neg_samples, embedding_dim]
neg_embeds = self.context_embeddings(neg_samples)
# 计算负样本得分
# [batch_size, num_neg_samples]
neg_scores = torch.bmm(
neg_embeds,
center_embeds.unsqueeze(2)
).squeeze(2)
neg_probs = torch.sigmoid(-neg_scores)
# 计算损失
# 正样本损失: -log(sigmoid(pos_score))
pos_loss = -torch.log(pos_probs + self.eps).mean()
# 负样本损失: -log(sigmoid(-neg_score))
neg_loss = -torch.log(neg_probs + self.eps).mean()
# 总损失
loss = pos_loss + neg_loss
# 计算分类准确率(用于监控训练进度)
with torch.no_grad():
pos_correct = (pos_probs > 0.5).float().sum()
neg_correct = (neg_probs > 0.5).float().sum()
total_samples = batch_size + batch_size * self.num_neg_samples
accuracy = (pos_correct + neg_correct) / total_samples
return loss, pos_loss, neg_loss, accuracy
def generate_synthetic_data(vocab_size, corpus_size, window_size=2):
"""
生成合成语料库及其中心词-上下文对
参数:
vocab_size (int): 词汇表大小
corpus_size (int): 语料库大小(词数)
window_size (int): 窗口大小
返回:
corpus (torch.Tensor): 语料库
center_words (torch.Tensor): 中心词索引列表
context_words (torch.Tensor): 上下文词索引列表
word_freqs (torch.Tensor): 词频分布
"""
# 生成遵循Zipf分布的随机语料库
word_probs = 1.0 / torch.arange(1, vocab_size + 1, dtype=torch.float) ** 0.75
word_probs /= word_probs.sum()
corpus = torch.multinomial(word_probs, corpus_size, replacement=True)
# 准备中心词和上下文词对
center_words, context_words = [], []
# 对每个位置生成中心词-上下文对
for i in range(corpus_size):
center_word = corpus[i]
# 定义上下文窗口范围
context_start = max(0, i - window_size)
context_end = min(corpus_size, i + window_size + 1)
# 收集上下文词
for j in range(context_start, context_end):
if i != j: # 排除中心词本身
center_words.append(center_word)
context_words.append(corpus[j])
# 计算词频
word_freqs = torch.zeros(vocab_size)
for word in corpus:
word_freqs[word] += 1
word_freqs = word_freqs / word_freqs.sum()
return (
corpus,
torch.tensor(center_words),
torch.tensor(context_words),
word_freqs
)
def train_word_embeddings():
"""
使用NCE损失函数训练词嵌入的完整流程
"""
# 设置超参数
vocab_size = 5000
embedding_dim = 100
num_neg_samples = 10
learning_rate = 0.001
batch_size = 512
epochs = 50
corpus_size = 10000
window_size = 2
# 选择设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 生成合成数据
print("生成合成语料库...")
corpus, center_words, context_words, word_freqs = generate_synthetic_data(
vocab_size, corpus_size, window_size
)
# 将数据移动到正确的设备上
corpus = corpus.to(device)
center_words = center_words.to(device)
context_words = context_words.to(device)
word_freqs = word_freqs.to(device)
# 创建数据加载器
dataset_size = len(center_words)
num_batches = (dataset_size - 1) // batch_size + 1
# 初始化模型和优化器
model = ImprovedNCELoss(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
num_neg_samples=num_neg_samples,
noise_distribution=word_freqs ** 0.75 # 使用平滑的词频分布作为噪声分布
).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=1
)
# 记录训练过程
training_stats = {
'total_loss': [],
'pos_loss': [],
'neg_loss': [],
'accuracy': [],
'time_per_epoch': []
}
print(f"开始训练,共{epochs}轮...")
# 训练循环
for epoch in range(epochs):
epoch_start_time = time.time()
epoch_total_loss = 0.0
epoch_pos_loss = 0.0
epoch_neg_loss = 0.0
epoch_accuracy = 0.0
# 随机打乱数据
indices = torch.randperm(dataset_size, device=device) # 确保在正确的设备上
# 使用tqdm创建进度条
pbar = tqdm(range(num_batches), desc=f"Epoch {epoch + 1}/{epochs}")
for batch_idx in pbar:
# 获取当前批次的索引
start_idx = batch_idx * batch_size
end_idx = min((batch_idx + 1) * batch_size, dataset_size)
batch_indices = indices[start_idx:end_idx]
# 获取中心词和上下文词批次
batch_center = center_words[batch_indices]
batch_context = context_words[batch_indices]
# 前向传播和反向传播
optimizer.zero_grad()
loss, pos_loss, neg_loss, accuracy = model(batch_center, batch_context)
loss.backward()
# 梯度裁剪以防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()
# 累加统计数据
batch_size_actual = end_idx - start_idx
epoch_total_loss += loss.item() * batch_size_actual
epoch_pos_loss += pos_loss.item() * batch_size_actual
epoch_neg_loss += neg_loss.item() * batch_size_actual
epoch_accuracy += accuracy.item() * batch_size_actual
# 更新进度条
pbar.set_postfix({
'loss': f"{loss.item():.4f}",
'acc': f"{accuracy.item():.4f}"
})
# 计算每轮的平均统计数据
epoch_total_loss /= dataset_size
epoch_pos_loss /= dataset_size
epoch_neg_loss /= dataset_size
epoch_accuracy /= dataset_size
epoch_time = time.time() - epoch_start_time
# 保存训练统计数据
training_stats['total_loss'].append(epoch_total_loss)
training_stats['pos_loss'].append(epoch_pos_loss)
training_stats['neg_loss'].append(epoch_neg_loss)
training_stats['accuracy'].append(epoch_accuracy)
training_stats['time_per_epoch'].append(epoch_time)
print(f"Epoch {epoch + 1}/{epochs} - "
f"Loss: {epoch_total_loss:.4f} "
f"(Pos: {epoch_pos_loss:.4f}, Neg: {epoch_neg_loss:.4f}) - "
f"Accuracy: {epoch_accuracy:.4f} - "
f"Time: {epoch_time:.2f}s")
# 学习率调度
scheduler.step(epoch_total_loss)
print("训练完成!")
# 可视化训练过程
visualize_training(training_stats)
# 可视化词嵌入
visualize_embeddings(model, corpus, 50)
return model, training_stats
def visualize_training(stats):
"""
可视化训练过程
参数:
stats (dict): 包含训练统计数据的字典
"""
plt.figure(figsize=(12, 8))
# 绘制损失曲线
plt.subplot(2, 2, 1)
plt.plot(stats['total_loss'], 'b-', label='Total Loss')
plt.plot(stats['pos_loss'], 'g--', label='Positive Loss')
plt.plot(stats['neg_loss'], 'r--', label='Negative Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('NCE Loss During Training')
plt.legend()
plt.grid(True)
# 绘制准确率曲线
plt.subplot(2, 2, 2)
plt.plot(stats['accuracy'], 'g-')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Binary Classification Accuracy')
plt.grid(True)
# 绘制每轮训练时间
plt.subplot(2, 2, 3)
plt.bar(range(1, len(stats['time_per_epoch']) + 1), stats['time_per_epoch'])
plt.xlabel('Epoch')
plt.ylabel('Time (seconds)')
plt.title('Training Time per Epoch')
plt.tight_layout()
plt.show()
def visualize_embeddings(model, corpus, top_n=50):
"""
使用t-SNE可视化词嵌入
参数:
model (ImprovedNCELoss): 训练好的模型
corpus (torch.Tensor): 语料库
top_n (int): 要可视化的高频词数量
"""
try:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# 获取词嵌入
embeddings = model.center_embeddings.weight.detach().cpu().numpy()
# 计算词频
word_counts = {}
for word_idx in corpus.cpu(): # 确保在CPU上处理
word_idx = word_idx.item()
word_counts[word_idx] = word_counts.get(word_idx, 0) + 1
# 选取出现频率最高的词
top_words = sorted(word_counts.items(), key=lambda x: x[1], reverse=True)[:top_n]
top_word_indices = [word_idx for word_idx, _ in top_words]
# 使用t-SNE降维
tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(top_word_indices) - 1))
reduced_embeddings = tsne.fit_transform(embeddings[top_word_indices])
# 绘制嵌入空间
plt.figure(figsize=(10, 8))
plt.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], s=50, c='steelblue', alpha=0.7)
# 添加词索引标签
for i, word_idx in enumerate(top_word_indices):
plt.annotate(f"word_{word_idx}",
(reduced_embeddings[i, 0], reduced_embeddings[i, 1]),
fontsize=9)
plt.title(f"t-SNE Visualization of Top {top_n} Word Embeddings")
plt.xlabel("t-SNE Dimension 1")
plt.ylabel("t-SNE Dimension 2")
plt.tight_layout()
plt.show()
except ImportError:
print("无法导入scikit-learn,跳过嵌入可视化")
def similarity_analysis(model, word_idx1, word_idx2):
"""
分析两个词的相似度
参数:
model (ImprovedNCELoss): 训练好的模型
word_idx1, word_idx2 (int): 要比较的词的索引
"""
# 获取词嵌入
embed1 = model.center_embeddings.weight[word_idx1].detach().cpu()
embed2 = model.center_embeddings.weight[word_idx2].detach().cpu()
# 计算余弦相似度
cos_sim = torch.nn.functional.cosine_similarity(embed1.unsqueeze(0), embed2.unsqueeze(0))
print(f"词 {word_idx1} 和词 {word_idx2} 的余弦相似度: {cos_sim.item():.4f}")
def find_similar_words(model, word_idx, top_k=5):
"""
找出与给定词最相似的top_k个词
参数:
model (ImprovedNCELoss): 训练好的模型
word_idx (int): 目标词索引
top_k (int): 返回的相似词数量
"""
# 获取所有词嵌入
all_embeddings = model.center_embeddings.weight.detach().cpu()
target_embedding = all_embeddings[word_idx]
# 计算余弦相似度
similarities = torch.nn.functional.cosine_similarity(
target_embedding.unsqueeze(0), all_embeddings)
# 获取最相似的词(排除自身)
similarities[word_idx] = -1 # 排除自身
top_similar = torch.topk(similarities, k=top_k)
print(f"与词 {word_idx} 最相似的 {top_k} 个词:")
for i, (idx, sim) in enumerate(zip(top_similar.indices, top_similar.values)):
print(f" {i + 1}. 词 {idx.item()}: 相似度 {sim.item():.4f}")
if __name__ == "__main__":
# 设置随机种子以确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
try:
# 训练模型
model, stats = train_word_embeddings()
# 示例:分析一些词的相似度
print("\n词相似度分析示例:")
similarity_analysis(model, 10, 20)
# 示例:寻找相似词
print("\n寻找相似词示例:")
find_similar_words(model, 5, top_k=5)
except Exception as e:
print(f"出现错误: {e}")
# 如果是CUDA相关错误,提供使用CPU的选项
if "CUDA" in str(e):
print("\n提示: 如果您的GPU内存不足或无法使用GPU,请修改代码改用CPU:")
print("将 device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")")
print("替换为 device = torch.device(\"cpu\")")
InfoNCE(SimCLR)实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# ------- SimCLR 数据增强 -------
class SimCLRTransform:
def __init__(self, size=32):
self.base_transform = transforms.Compose([
transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=3),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.247, 0.243, 0.261])
])
def __call__(self, x):
return self.base_transform(x), self.base_transform(x)
# ------- 编码器 -------
class Encoder(nn.Module):
def __init__(self):
super().__init__()
resnet = torchvision.models.resnet18(weights=None) # 替代 pretrained=False
self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])
self.feature_dim = resnet.fc.in_features
def forward(self, x):
x = self.feature_extractor(x)
return x.view(x.size(0), -1)
# ------- 投影头 -------
class ProjectionHead(nn.Module):
def __init__(self, in_dim=512, hidden_dim=2048, out_dim=128):
super().__init__()
self.projection = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, out_dim),
nn.BatchNorm1d(out_dim)
)
def forward(self, x):
return self.projection(x)
# ------- InfoNCE 损失 -------
def info_nce_loss(z1, z2, temperature=0.5):
batch_size = z1.shape[0]
z = torch.cat([z1, z2], dim=0)
z = F.normalize(z, dim=1)
sim_matrix = torch.matmul(z, z.T) / temperature
labels = torch.arange(batch_size).to(z1.device)
labels = torch.cat([labels, labels], dim=0)
positive_mask = F.one_hot(labels, num_classes=2 * batch_size).float()
positive_mask = torch.roll(positive_mask, shifts=batch_size, dims=1)
logits_mask = 1 - torch.eye(2 * batch_size, device=z.device)
exp_logits = torch.exp(sim_matrix) * logits_mask
log_prob = sim_matrix - torch.log(exp_logits.sum(1, keepdim=True) + 1e-8)
mean_log_prob_pos = (positive_mask * log_prob).sum(1) / positive_mask.sum(1)
loss = -mean_log_prob_pos.mean()
return loss
# ------- SimCLR 训练 -------
def train_simclr(batch_size=256, epochs=100, temperature=0.5, device='cuda'):
transform = SimCLRTransform()
train_dataset = CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
encoder = Encoder().to(device)
projection = ProjectionHead(encoder.feature_dim).to(device)
optimizer = Adam(list(encoder.parameters()) + list(projection.parameters()), lr=1e-3)
losses = []
for epoch in range(epochs):
encoder.train()
projection.train()
epoch_loss = 0
for (x1, x2), _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
x1, x2 = x1.to(device), x2.to(device)
h1, h2 = encoder(x1), encoder(x2)
z1, z2 = projection(h1), projection(h2)
loss = info_nce_loss(z1, z2, temperature)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
avg_loss = epoch_loss / len(train_loader)
losses.append(avg_loss)
print(f"[Epoch {epoch+1}] Loss: {avg_loss:.4f}")
# 保存损失曲线图
plt.figure()
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("InfoNCE Loss")
plt.title("SimCLR Training Loss")
plt.savefig("simclr_loss_curve.png")
print("✅ Loss 曲线已保存为 simclr_loss_curve.png")
return encoder
# ------- Linear Probe 线性分类器 -------
def linear_probe(encoder, device='cuda'):
encoder.eval()
for param in encoder.parameters():
param.requires_grad = False
# 用简单 transform 获取特征
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.247, 0.243, 0.261])
])
train_dataset = CIFAR10(root='./data', train=True, transform=test_transform, download=True)
test_dataset = CIFAR10(root='./data', train=False, transform=test_transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
classifier = nn.Linear(encoder.feature_dim, 10).to(device)
optimizer = Adam(classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# 训练线性分类器
for epoch in range(10):
classifier.train()
total_loss = 0
for x, y in train_loader:
x, y = x.to(device), y.to(device)
with torch.no_grad():
feat = encoder(x)
pred = classifier(feat)
loss = criterion(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"[Linear Probe Epoch {epoch+1}] Loss: {total_loss/len(train_loader):.4f}")
# 测试集准确率
classifier.eval()
correct = 0
total = 0
with torch.no_grad():
for x, y in test_loader:
x, y = x.to(device), y.to(device)
feat = encoder(x)
pred = classifier(feat)
correct += (pred.argmax(1) == y).sum().item()
total += y.size(0)
acc = 100 * correct / total
print(f"✅ Linear Probe Accuracy: {acc:.2f}%")
return acc
# ------- 特征可视化 -------
def visualize_tsne(encoder, device='cuda'):
encoder.eval()
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.247, 0.243, 0.261])
])
dataset = CIFAR10(root='./data', train=False, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=256, shuffle=False)
features, labels = [], []
with torch.no_grad():
for x, y in tqdm(loader, desc="Extracting features for t-SNE"):
x = x.to(device)
f = encoder(x)
features.append(f.cpu())
labels.append(y)
features = torch.cat(features, dim=0).numpy()
labels = torch.cat(labels, dim=0).numpy()
tsne = TSNE(n_components=2, perplexity=30, init='pca', random_state=0)
embedding = tsne.fit_transform(features)
plt.figure(figsize=(10, 8))
scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='tab10', alpha=0.6)
plt.legend(*scatter.legend_elements(), title="Classes")
plt.title("t-SNE of SimCLR Features")
plt.savefig("simclr_tsne.png")
print("✅ t-SNE 图已保存为 simclr_tsne.png")
# ------- 主函数 -------
if __name__ == "__main__":
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = train_simclr(batch_size=256, epochs=50, temperature=0.5, device=device)
linear_probe(encoder, device=device)
visualize_tsne(encoder, device=device)
SimSiam案例
SimSiam概述
SimSiam架构的流程:从一个原始图像 x x x 生成两个增强视图 x 1 x_1 x1 和 x 2 x_2 x2,然后通过特定的网络结构处理这两个视图,最终学习到图像的有意义的表示。图中包含以下关键组件:
- 两个相同的编码器网络(encoder f f f):用蓝色矩形表示,处理 x 1 x_1 x1 和 x 2 x_2 x2。
- 预测器(predictor h h h):用橙色小矩形表示,作用于 x 1 x_1 x1 的输出。
- 停止梯度(stop-grad)操作:应用于 x 2 x_2 x2 的输出,阻止梯度反向传播。
- 相似性(similarity):通过箭头表示,目标是最大化预测器输出和停止梯度后的输出之间的相似性。
SimSiam架构的详细解释
1. 输入图像
- 图中显示了两个增强视图 x 1 x_1 x1 和 x 2 x_2 x2,它们是从同一个原始图像 x x x 通过随机变换(如裁剪、旋转、颜色变换等)生成的。
- 自监督学习的核心思想是利用这些增强视图之间的关系,让模型学习图像的内在表示。
2. 编码器网络(encoder f f f)
- 组成:
- 骨干网络(backbone):通常是一个卷积神经网络(如ResNet),用于提取图像的特征。
- 投影多层感知机(projection MLP):一个小型神经网络,将骨干网络的输出映射到一个低维嵌入空间。
- 功能:两个编码器 f f f 分别处理 x 1 x_1 x1 和 x 2 x_2 x2,生成它们的嵌入表示。这两个编码器是相同的,共享参数。
- 图示:用蓝色矩形表示,分别标注为 f f f。
3. 预测器(predictor h h h)
- 组成:一个小型的多层感知机(MLP)。
- 功能:作用于编码器 f f f 对 x 1 x_1 x1 的输出(即 f ( x 1 ) f(x_1) f(x1)),生成一个预测表示,试图预测 x 2 x_2 x2 的表示。
- 图示:用橙色小矩形表示,连接在 x 1 x_1 x1 的编码器输出之后。
4. 停止梯度(stop-grad)
- 功能:应用于编码器 f f f 对 x 2 x_2 x2 的输出(即 f ( x 2 ) f(x_2) f(x2)),阻止梯度通过这个路径反向传播。这意味着在训练时,编码器 f f f 的参数不会通过 x 2 x_2 x2 的路径更新。
- 图示:用一个类似“x”的符号和向下箭头表示,位于 x 2 x_2 x2 的编码器输出之后。
5. 相似性(similarity)
- 功能:图中有一个标有“similarity”的箭头,连接预测器 h h h 的输出(基于 x 1 x_1 x1)和停止梯度后的 f ( x 2 ) f(x_2) f(x2)。模型的目标是最大化这两者之间的相似性。
- 实现:通常使用余弦相似度作为损失函数,鼓励两个表示向量之间的夹角变小。
SimSiam的工作原理
SimSiam通过以下步骤学习图像的表示:
- 从原始图像 x x x 生成两个增强视图 x 1 x_1 x1 和 x 2 x_2 x2。
- 将 x 1 x_1 x1 和 x 2 x_2 x2 输入两个相同的编码器 f f f,分别得到嵌入表示 f ( x 1 ) f(x_1) f(x1) 和 f ( x 2 ) f(x_2) f(x2)。
- 对 f ( x 1 ) f(x_1) f(x1) 应用预测器 h h h,生成预测表示 h ( f ( x 1 ) ) h(f(x_1)) h(f(x1))。
- 对 f ( x 2 ) f(x_2) f(x2) 应用停止梯度操作,得到 stop-grad ( f ( x 2 ) ) \text{stop-grad}(f(x_2)) stop-grad(f(x2))。
- 模型通过最小化 h ( f ( x 1 ) ) h(f(x_1)) h(f(x1)) 和 stop-grad ( f ( x 2 ) ) \text{stop-grad}(f(x_2)) stop-grad(f(x2)) 之间的差异(即最大化它们的相似性),学习对图像内容不变的表示。
stop-grad
停止梯度操作(stop-gradient operation)是一种在神经网络训练中使用的技术,用来阻止梯度在反向传播时流向某些特定的路径。在深度学习框架中,这通常通过特定的函数实现,比如在 PyTorch 中使用 detach()
,在 TensorFlow 中使用 tf.stop_gradient()
。
1. 前向传播:stop-grad(f(x₂))=f(x₂)
- 在前向传播(也就是从输入到输出的计算过程)中,
stop-grad(f(x₂))
的值就是f(x₂)
。这里的f
通常是一个编码器(比如神经网络),x₂
是输入数据,f(x₂)
是编码器对x₂
的输出。 stop-grad
只是一个操作,它不会改变f(x₂)
的数值结果。所以,如果你在损失函数中用到了stop-grad(f(x₂))
,比如计算h(f(x₁))
和stop-grad(f(x₂))
之间的差异,实际上就是在计算h(f(x₁))
和f(x₂)
的差异。- 简单来说,前向传播时,
stop-grad
“看不见”它的作用,它就像一个透明的传递层。
2. 反向传播:stop-grad 的关键作用
- 在反向传播(也就是计算梯度并更新参数的过程)中,
stop-grad
的作用就显现出来了:它会阻止梯度通过f(x₂)
回传到编码器f
。 - 具体来说,当你计算损失函数关于模型参数的梯度时,
stop-grad(f(x₂))
被视为一个常量。常量在求导时梯度为零,所以梯度不会流回f
,也不会影响f
的参数更新。 - 举个例子:
- 假设损失函数是 L = ∣ ∣ h ( f ( x 1 ) ) − s t o p − g r a d ( f ( x 2 ) ) ∣ ∣ 2 L=||h(f(x₁))-stop-grad(f(x₂))||² L=∣∣h(f(x1))−stop−grad(f(x2))∣∣2。
- 前向传播时,计算的是 h ( f ( x 1 ) ) h(f(x₁)) h(f(x1)) 和 f ( x 2 ) f(x₂) f(x2) 的差。
- 反向传播时,梯度会流向 h h h 和 f ( x 1 ) f(x₁) f(x1),但因为
stop-grad
,梯度不会流向 f ( x 2 ) f(x₂) f(x2) 或编码器 f f f 在处理 x 2 x₂ x2 时的参数。
3. 为什么会有这种区别?
- 你可能会问:既然前向传播时
stop-grad(f(x₂))
就是f(x₂)
,为什么不在损失函数里直接用 f ( x 2 ) f(x₂) f(x2)? - 答案在于模型训练的目标。如果直接用 f ( x 2 ) f(x₂) f(x2),梯度会同时流向 f ( x 1 ) f(x₁) f(x1) 和 f ( x 2 ) f(x₂) f(x2),模型可能会“作弊”:让 f ( x 1 ) f(x₁) f(x1) 和 f ( x 2 ) f(x₂) f(x2) 输出相同的值(比如全零向量),这样损失会很小,但这种表示没有意义,因为所有输入都被映射成了同一个点。
- 使用
stop-grad(f(x₂))
,梯度只流向 f ( x 1 ) f(x₁) f(x1) 和 h h h,迫使模型调整 h ( f ( x 1 ) ) h(f(x₁)) h(f(x1)) 去匹配 f ( x 2 ) f(x₂) f(x2),而 f ( x 2 ) f(x₂) f(x2) 保持不变(因为没有梯度更新)。这能让编码器 f f f 学到更有意义的表示,而不是简单的坍塌解。
- s t o p − g r a d ( f ( x 2 ) ) stop-grad(f(x₂)) stop−grad(f(x2)) 在前向传播时就是 f ( x 2 ) f(x₂) f(x2),数值上完全一样。
- 但在反向传播时,
stop-grad
阻止了梯度回传到 f ( x 2 ) f(x₂) f(x2),这使得模型的参数更新只依赖于 h ( f ( x 1 ) ) h(f(x₁)) h(f(x1)) 的调整,而不是直接改变 f ( x 2 ) f(x₂) f(x2)。 - 这种设计是为了避免模型学到退化的解,确保编码器 f f f 能提取出对输入数据有意义的特征。
SimSiam代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch.optim import SGD
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
# -------------------------------
# 1. 数据增强模块(两个视图)
# -------------------------------
class SimSiamTransform:
"""SimSiam 的数据增强模块,生成两个视图(类似 SimCLR)"""
def __init__(self, image_size=32):
self.transform = transforms.Compose([
transforms.RandomResizedCrop(image_size, scale=(0.2, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.RandomGrayscale(p=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])
])
def __call__(self, x):
return self.transform(x), self.transform(x)
# -------------------------------
# 2. Encoder + Projection + Prediction Head
# -------------------------------
class MLPHead(nn.Module):
"""投影/预测头,Linear -> BN -> ReLU -> Linear"""
def __init__(self, in_dim, hidden_dim=2048, out_dim=2048):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, out_dim)
)
def forward(self, x):
return self.net(x)
class SimSiam(nn.Module):
"""SimSiam 主网络:encoder -> projection -> prediction"""
def __init__(self, backbone_dim=512):
super().__init__()
# 预训练 backbone(去掉最后全连接层)
backbone = models.resnet18(weights=None) # 不加载预训练权重
self.encoder = nn.Sequential(*list(backbone.children())[:-1]) # 去除分类层
# 投影头 g()
self.projector = MLPHead(backbone_dim, 2048, 2048)
# 预测头 h()
self.predictor = MLPHead(2048, 512, 2048)
def forward(self, x1, x2):
# 提取两个视图的特征
f1 = self.encoder(x1).flatten(start_dim=1) # [B, 512]
f2 = self.encoder(x2).flatten(start_dim=1)
z1 = self.projector(f1) # g(f(x1))
z2 = self.projector(f2)
p1 = self.predictor(z1) # h(g(f(x1)))
p2 = self.predictor(z2)
return p1, p2, z1.detach(), z2.detach() # stop-grad 作用于 z1/z2
# -------------------------------
# 3. SimSiam 损失函数(我们希望两个特征向量(预测结果和目标结果)越接近越好,也就是它们之间的余弦相似度越高,最好能接近 1。由于我们使用的优化器都是“最小化”损失的,所以我们把“相似度越高(越好)”转变成“损失越低”。把余弦相似度取负:当两个向量很相似时,余弦相似度接近 1,负值就是 -1;而如果不相似时,负值会更高(例如 0 或正值)因此,最小化这个负值就能让模型学到让两个向量更相似的表示。)
# -------------------------------
def D(p, z):
"""负向 cosine similarity"""
p = F.normalize(p, dim=1)
z = F.normalize(z, dim=1)
return - (p * z).sum(dim=1).mean()
def simsiam_loss(p1, p2, z1, z2):
return D(p1, z2) / 2 + D(p2, z1) / 2
# -------------------------------
# 4. 训练过程
# -------------------------------
def train_simsiam(batch_size=512, epochs=50, lr=0.05):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = SimSiamTransform()
dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)
model = SimSiam().to(device)
optimizer = SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
loss_list = []
for epoch in range(1, epochs + 1):
model.train()
total_loss = 0
for (x1, x2), _ in tqdm(loader, desc=f"Epoch {epoch}/{epochs}"):
x1, x2 = x1.to(device), x2.to(device)
p1, p2, z1, z2 = model(x1, x2)
loss = simsiam_loss(p1, p2, z1, z2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(loader)
loss_list.append(avg_loss)
print(f"[Epoch {epoch}] Loss: {avg_loss:.4f}")
# 保存训练损失图
os.makedirs("output", exist_ok=True)
plt.plot(loss_list)
plt.title("SimSiam Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.savefig("output/simsiam_loss_curve.png")
plt.close()
print("训练完成,损失图已保存到 output/simsiam_loss_curve.png")
return model
if __name__ == "__main__":
train_simsiam()