原文链接:[2106.14413] Co$^2$L: Contrastive Continual Learning
阅读本文前,需要对持续学习的基本概念以及面临的问题有大致了解,可参考综述:
Wang L, Zhang X, Su H, et al. A comprehensive survey of continual learning: Theory, method and application[J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2024.
一、概述
该论文提出了一种名为 Co²L(对比持续学习) 的新方法,旨在解决持续学习中的灾难性遗忘问题。核心思想是通过对比学习(Contrastive Learning)和自监督蒸馏(Self-supervised Distillation)来持续学习和维护可迁移的表示。
1、对比学习
论文中提到了对比学习在持续学习中比传统的联合训练方法更能抵抗灾难性遗忘。要理解这句话,需要先弄懂对比学习和联合训练的策略分别是什么。
(1)对比学习:一种自监督学习方法,其核心思想是通过比较样本之间的相似性与差异性来学习数据的本质特征。核心机制是正负样本对的构建,一般对比学习中,会将一个数据样本的不同数据增强(旋转、缩放、颜色变换、剪切等)版本作为它的正样本,将不同数据样本作为负样本,学习过程中强制正样本之间的特征表达尽量接近,负样本尽量远离。在持续学习场景中,会将旧任务的样本全部视作负样本,以此让模型更好学习新任务的表达。
(2)联合训练:在本文所提到的持续学习场景下,所谓的联合训练实际就是传统的有监督学习,只是会在训练新任务时,回放一些旧任务的数据一并训练,以及采用正则化等手段来限制参数更新对旧任务的影响。
我们可以发现,对比学习通过数据增强,学习到的是更通用的特征,比如形状、纹理等,这些特征在不同任务间具有高度共享性,支持新任务的快速适应。而联合训练依赖特定监督标签,导致新任务学习时容易覆盖旧任务的敏感特征,引发遗忘。
这么说比较抽象,结合下图从损失函数的角度来分析更方便理解,对比学习目标是学习对数据变换不变的表示,可能通过数据增强和对比损失函数促进参数空间的平滑性,对应的损失函数比较平坦。而联合训练由于要同时优化多个任务的损失,可能导致不同任务之间的梯度方向冲突,容易形成尖锐的极小值。很显然,平坦的损失函数更有利于我们在参数空间寻找到一个平衡点,保证对新旧任务的整体影响最小,从而缓解灾难性遗忘问题。
经过上述分析,我们可以初步下结论,在持续学习中使用对比学习有利于抵抗灾难性遗忘问题。
但是传统对比学习在持续学习中存在一个问题,就是我们只保留少量的旧任务样本,这导致负样本的数量以及多样性不足,会影响对比学习效果。
2、知识蒸馏
知识蒸馏通常用于将“教师模型”的知识迁移到“学生模型”,传统蒸馏方式包括输出蒸馏,即让学生模型输出的概率分尽量接近于教师模型,通过交叉熵损失实现:
还有特征蒸馏,即选择教师模型与学生模型的某一中间层,目标是最小化中间层的特征表示,通过L2损失或余弦相似度实现:
输出蒸馏依赖于输出概率分布,但是在持续学习中,往往只保留少量旧任务数据,导致模型难以准确复现旧任务的输出概率,所以输出蒸馏不适合持续学习。
特征蒸馏需要对中间层的特征进行对齐,但在持续学习中,随着新任务的训练,模型的结构是在不断调整的,中间层的特征空间也会改变,这就导致我们无法稳定的用旧任务的特征表示来对新任务的特征表示进行蒸馏,所以特征蒸馏也不适合持续学习。
二、贡献点
为了解决对比学习在持续学习中负样本不足的问题,提出了非对称监督对比损失(Asymmetric SupCon)的解决方案。
为了解决传统知识蒸馏方式在持续学习场景下不适用的问题,提出了一种实例关系蒸馏的自监督蒸馏方式。
三、方法
1、非对称监督对比损失(Asymmetric Supervised Contrastive Loss)
核心思想很简单,原本的对比学习会将旧任务和新任务的数据样本一视同仁,每个样本都有参与到正样本对和负样本对的构建,而本文提出,我们只用新任务的数据样本构建正样本对,所有的旧任务数据都只作为负样本。这样能让模型更专注于新任务,更好的学习到新旧任务的边界。
(举个例子,假设旧任务有“猫”、“狗”、“鱼”,新任务是“鸟”,如果是传统对比学习,我们就需要让模型同时学习到四个类别的特征,只不过前三个类别因为已经学习过,现在学起来更容易,但存在问题是前三个类别的样本数目很少,容易在前三个类别上出现过拟合。而新的思路是,我既然已经能识别出前三类,那我只需要学习“鸟”的特征,并保证“鸟”的特征与前三类特征尽量远离。)
损失函数形式如下,S表示当前任务的样本的集合,对于锚点样本i,p_i是正样本对的集合,z_i和z_p是锚点样本与正样本经过投影后的特征向量,两者的点积表示余弦相似度。k表示的是负样本,所以分母是锚点样本与所有负样本之间的相似度之和。
2、实例关系蒸馏(Instance-wise Relation Distillation, IRD)
目标是通过维护样本间的相似性结构来减少持续学习中的灾难性遗忘。IRD特点是IRD 不直接对齐单个样本的特征或分类结果,而是关注样本间的全局关系结构(如“猫”和“狗”应彼此接近,“猫”和“鸟”应远离),从而保留旧任务的特征空间结构。
(1)参考模型保存
在完成第 t−1t−1 个任务后,将模型的参数(包括编码器 和投影头
)保存为参考模型
,并冻结其参数。
(2)实例相似性向量计算
对于一个包含 2N 个增强样本的批次 B,每个样本 x~i 的实例相似性向量定义如下:
其中
(3)蒸馏损失函数
通过交叉熵损失对齐参考模型与当前模型的相似性分布:
四、实验结果
在各类增量学习基准中都取得明显提升
通过消融实验证明了非对称监督对比损失和实例关系蒸馏的有效性。
五、总结
本文通过改进后的对比学习强化了模型适应新任务的能力,同时,通过实例关系蒸馏,让模型能保留旧任务上的性能,缓解灾难性遗忘问题。