上下文感知元学习(Context-Aware Meta-Learning, CAML)
概述
在机器学习和深度学习领域,元学习(Meta-Learning)旨在通过学习如何学习,使模型能够在面对新任务时快速适应。传统的元学习方法通常需要在特定任务上进行微调,这不仅耗时,而且可能需要大量的计算资源。为了解决这一问题,斯坦福大学的研究团队提出了一种名为**上下文感知元学习(Context-Aware Meta-Learning, CAML)**的新算法。CAML通过在推理过程中学习新的视觉概念,而无需进行微调,模拟了大型语言模型(如ChatGPT)在推理过程中学习新概念的能力。
CAML算法的论文链接
本人空间-论文英文版pdf下载
背景与动机
元学习的挑战
传统的元学习方法通常依赖于在特定任务上进行微调,这不仅耗时,而且可能需要大量的计算资源。此外,这些方法在面对新任务时,往往表现不佳,尤其是在少样本学习(Few-Shot Learning)场景中。
大型语言模型的启发
近年来,大型语言模型(如ChatGPT)展示了在推理过程中学习新概念的出色能力。这些模型能够在无需微调的情况下,通过上下文学习新的概念。然而,类似的视觉模型在推理过程中学习新概念的能力却相对较弱,要么表现不佳,要么需要在类似对象上进行元训练或微调。
CAML的动机
CAML的提出旨在模拟大型语言模型的行为,通过在推理过程中学习新的视觉概念,而无需进行微调。这种方法不仅提高了模型的适应性,还减少了计算资源的消耗。
方法
序列建模
CAML将元学习任务建模为一个序列建模问题。输入序列由支持集图像和标签以及查询图像组成。支持集图像和标签用于提供上下文信息,而查询图像则是需要预测标签的图像。
Transformer编码器
CAML使用了一个基于Transformer编码器的序列模型。Transformer编码器被训练来从输入序列中提取特征,并预测查询图像的标签。Transformer的结构使其能够处理长距离依赖关系,从而更好地捕捉上下文信息。
等长最大角度集(ELMES)编码
支持集标签使用等长最大角度集(Equal-Length Maximum Angle Set, ELMES)编码进行编码。ELMES编码通过最小化检测支持集内类的熵,提高了模型的性能。实验结果表明,ELMES编码是最优的。
预训练的CLIP模型
CAML利用预训练的CLIP模型作为固定的特征提取器。CLIP模型在大规模图像和文本数据上进行了预训练,能够提取出丰富的视觉特征。CAML仅训练Transformer编码器的参数,而保持CLIP模型的参数不变。
# 使用timm中的预训练CLIP编码器,进行图像统一编码。
timm_feature_model_name, dim = "vit_base_patch16_clip_224.openai", 768
timm_feature_model_path = os.path.join(str(project_root()), "timm/vit_base_patch16_clip_224.openai/open_clip_pytorch_model.bin")
feature_extractor = timm.create_model(timm_feature_model_name,
pretrained=True,
pretrained_cfg_overlay=dict(file=timm_feature_model_path),
img_size=224,
num_classes=0).eval()
实验与结果
数据集
CAML在多个数据集上进行了实验,包括ImageNet、COCO等。这些数据集涵盖了不同的视觉任务,如分类、检测等。
基准测试
CAML在11个元学习基准测试中进行了评估。实验结果表明,在不进行任何元训练或微调的情况下,CAML在8个基准测试中匹配或超过了最先进的元学习方法。
性能分析
分析表明,CAML的表示会根据查询图像而改变,这使其能针对每个任务适当地选择支持集。经验结果也表明,ELMES编码是最优的。
优势
无需微调
CAML的最大优势在于无需进行微调。传统的元学习方法通常需要在特定任务上进行微调,这不仅耗时,而且可能需要大量的计算资源。CAML通过在推理过程中学习新的视觉概念,避免了这一问题。
def meta_infer(self, inp, support_labels, query_tensor):
"""For evaluating typical Meta-Learning Datasets."""
# 提取特征向量(假设 inp 是 support set)
support_features = self.get_feature_vector(inp) # 划分 support set
query_features = self.get_feature_vector(query_tensor) # 划分query set
# # 拼接两个特征,变成上下文。
support = support_features.unsqueeze(0)
query = query_features.unsqueeze(0)
feature_sequences = torch.cat([query, support], dim=1)
# 输入上下文, 使用 Transformer 编码器进行推理
logits = self.transformer_encoder.forward_imagenet_v2(feature_sequences, support_labels, way=None, shot=None)
# 添加 softmax 操作,输出概率分布
probabilities = torch.softmax(logits, dim=1)
_, max_index = torch.max(logits, 1)
# 返回概率分布和预测结果
return max_index, probabilities
高性能
在11个元学习基准测试中,CAML在8个测试中超过或与现有的元学习算法表现相当。这些算法都是在这些基准测试中进行了元训练的。
模拟大型语言模型的行为
CAML通过在推理过程中学习新的视觉概念,模拟了大型语言模型的行为。这不仅提高了模型的适应性,还减少了计算资源的消耗。
结论
CAML提出了一种无需微调的元学习算法,通过在推理过程中学习新的视觉概念,模拟了大型语言模型的行为,并在多个基准测试中取得了与现有算法相当的性能。CAML的成功不仅展示了元学习在视觉任务中的潜力,还为未来的研究提供了新的方向。
代码
关注本人博客,后续将基于作者的算法进行代码开发与实验讲解。