Abstract
MM-CDFSL is a model designed for first-person action recognition that leverages cross-domain few-shot learning. Its core innovation lies in the deep integration of multi-modal collaborative learning with a cross-domain dynamic inference mechanism. This approach is the first to introduce knowledge distillation from multi-modal teacher models (such as those using optical flow and hand pose heat maps) into an RGB student model. By utilizing unlabeled target domain data for cross-modal feature alignment, the single-modal inference model is enabled to inherit the robustness of multi-source information against variations in lighting and background. To tackle the high computational cost associated with the redundant spatiotemporal information in first-person videos, the model introduces a tubular dynamic masking strategy. During training, random spatiotemporal blocks of the input sequence are masked to simulate the loss of information in real-world scenarios, while at inference time, ensemble predictions from multiple masked versions compensate for any performance degradation, thereby balancing computational efficiency and recognition accuracy. The model also innovatively constructs a two-stage pre-training framework that jointly optimizes feature reconstruction and category discrimination objectives across both source and target domains. By leveraging the complementary nature of cross-domain data, a shared representation space is established, effectively mitigating the extreme domain differences between industrial settings and everyday scenarios. Furthermore, the approach decouples multi-modal learning from few-shot adaptation. During meta-training, stable knowledge transfer is achieved by freezing the teacher model parameters, and during meta-testing, only a lightweight classifier requires fine-tuning to quickly adapt to new categories. This modular design significantly reduces the complexity of practical deployment. However, despite its strong performance in action recognition, the model still faces several shortcomings: it depends on the completeness of multi-modal data in the target domain, the masking strategy disrupts spatiotemporal continuity, it assumes an idealized scenario for shared domain features, it inadequately handles dynamic viewpoint shifts, and it underestimates the target domain's data requirements.
1. Introduction
Egocentric video datasets are extremely scarce in certain specialized domains, so cross-domain methods are needed to transfer knowledge learned from large-scale datasets to the target domain. Because the actions in the target domain may differ from those in the source domain, and annotating a target-domain dataset is very time-consuming, this problem is addressed with cross-domain few-shot learning, which combines the adaptability of cross-domain transfer with the efficiency of few-shot learning. Cross-domain few-shot learning uses two training stages, pre-training and domain adaptation, and two prediction stages, few-shot training and inference.
2. Framework
The MM-CDFSL pipeline: First, in the meta-training stage the model performs multi-modal pre-training and knowledge distillation. Based on the VideoMAE architecture, the RGB, optical flow, and hand pose modalities are each trained independently; by jointly optimizing a reconstruction loss and a classification loss, the model learns discriminative features shared by the source and target domains. A multi-modal distillation mechanism is then introduced to transfer the feature knowledge of the pre-trained optical flow and hand pose teacher models into the student RGB model. A feature alignment loss (e.g., an L2 distance) computed on unlabeled target-domain data strengthens adaptation to the target environment, while the teacher parameters are frozen to keep the knowledge transfer stable. After feature distillation, the model enters few-shot adaptation in the meta-testing stage: a small number of labeled samples from the target-domain support set (N-way K-shot data) are used to fine-tune a lightweight classifier, while the parameters of the student RGB encoder are kept fixed to preserve its cross-domain feature representation. To address the computational bottleneck of practical deployment, the model adopts a dynamic tube-masking strategy at inference time: spatio-temporal blocks of the input video are randomly masked to reduce the number of input tokens, and the predictions of several masked versions are ensembled to compensate for the information loss, which significantly speeds up inference while preserving recognition accuracy.
2.1 Task objective
The input to cross-domain few-shot learning is multi-modal source-domain data plus unlabeled target-domain data. The task is to classify the novel classes that appear in the target dataset $D_T$ while exploiting the labeled source dataset $D_S$ and the unlabeled target dataset $D_{T_u}$. Both $D_S$ and $D_{T_u}$ contain three modalities (RGB, optical flow, and hand pose heatmaps), and $D_S$ and $D_T$ share no classes. To make predictions on the novel classes of the target dataset, $D_T$ is split into a support set $S$ containing $N$ classes with $K$ samples per class and a query set $Q$ drawn from the same $N$ classes as the support set.
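The $N$-way $K$-shot episode construction can be illustrated with a minimal sketch (a hypothetical helper, not taken from the released code), which samples $N$ classes and splits each class's clips into $K$ support samples and $Q$ query samples:

```python
import random
from collections import defaultdict

def sample_episode(dataset, n_way=5, k_shot=5, q_sample=5):
    """Minimal N-way K-shot episode sampler (illustrative only).

    `dataset` is assumed to be a list of (clip, label) pairs drawn from D_T.
    Returns support and query lists of (clip, label).
    """
    by_class = defaultdict(list)
    for clip, label in dataset:
        by_class[label].append(clip)

    classes = random.sample(list(by_class.keys()), n_way)
    support, query = [], []
    for c in classes:
        clips = random.sample(by_class[c], k_shot + q_sample)
        support += [(x, c) for x in clips[:k_shot]]   # K labeled shots per class
        query += [(x, c) for x in clips[k_shot:]]     # Q query clips per class
    return support, query
```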
2.2 Pre-training of domain-adaptive and class-discriminative features
The patch embedding used in pre-training first permutes the input from $[B, T, C, H, W]$ to $[B, C, T, H, W]$ and then applies a 3D convolution with kernel size $T_1 \times H_1 \times W_1$ and stride $T_1 \times H_1 \times W_1$, producing an output of shape $[B, C, \lfloor T/T_1 \rfloor, \lfloor H/H_1 \rfloor, \lfloor W/W_1 \rfloor]$, where $C$ now denotes the embedding dimension. This output is flattened to $[B, C, N]$ and finally transposed to $[B, N, C]$, with $N$ the number of tokens. A positional embedding is then computed for these patch tokens and added to them to obtain the final embedding.
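A minimal sketch of this tubelet patch embedding, following the shapes described above (the name `PatchEmbed3D` and the default sizes are illustrative, not taken from the repository):

```python
import torch
import torch.nn as nn

class PatchEmbed3D(nn.Module):
    """Tubelet patch embedding: [B, T, C, H, W] -> [B, N, embed_dim]."""

    def __init__(self, in_chans=3, embed_dim=768, tubelet=(2, 16, 16)):
        super().__init__()
        # kernel size equals stride, so tubelets do not overlap
        self.proj = nn.Conv3d(in_chans, embed_dim,
                              kernel_size=tubelet, stride=tubelet)

    def forward(self, x):                      # x: [B, T, C, H, W]
        x = x.permute(0, 2, 1, 3, 4)           # [B, C, T, H, W]
        x = self.proj(x)                       # [B, C', T/T1, H/H1, W/W1]
        x = x.flatten(2)                       # [B, C', N]
        return x.transpose(1, 2)               # [B, N, C']

# e.g. 16 frames of 224x224 RGB -> 8 * 14 * 14 = 1568 tokens
tokens = PatchEmbed3D()(torch.randn(2, 16, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 1568, 768])
```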
The masking used in pre-training is tube masking guided by the masking ratio $\rho_{pretrain}$: rounding $N \times \rho_{pretrain}$ gives the number of tokens to mask, $M$. $M$ ones are concatenated with $N-M$ zeros and the sequence is shuffled to obtain the mask. Converting the mask to booleans and inverting it selects the tokens fed into the encoder $\epsilon$. The encoder $\epsilon$ is a stack of Transformer encoder blocks. The encoder output and a learnable mask token for each masked position are each given positional embeddings, concatenated, and passed to the decoder $D$. The decoder $D$ is likewise a stack of Transformer encoder blocks, followed by a linear layer that maps the embedding dimension to $C \times T_1 \times H_1 \times W_1$ (the number of pixel values in one tubelet).
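A minimal sketch of this mask generation, following the description above (the real tube masking typically draws one spatial pattern and repeats it across time; the helper name is illustrative):

```python
import numpy as np

def tube_mask(num_tokens: int, mask_ratio: float) -> np.ndarray:
    """Boolean mask over N tokens, True = masked (illustrative sketch).

    As described above: round(N * rho) ones concatenated with N - M zeros,
    then shuffled. The inverted mask selects the visible tokens fed to the
    encoder.
    """
    num_mask = int(round(num_tokens * mask_ratio))
    mask = np.concatenate([np.ones(num_mask), np.zeros(num_tokens - num_mask)])
    np.random.shuffle(mask)
    return mask.astype(bool)

mask = tube_mask(num_tokens=1568, mask_ratio=0.9)
visible_idx = np.where(~mask)[0]     # indices of tokens fed to the encoder
print(mask.sum(), len(visible_idx))  # 1411 masked, 157 visible
```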
The encoder-decoder framework described above is VideoMAE, so its pre-trained weights can be used. Given an input $x_m \in \mathbb{R}^{T \times H_m \times W_m \times C_m}$, where $m \in \{\text{RGB}, \text{flow}, \text{pose heatmap}\}$, passing it through VideoMAE yields the reconstructed input:
$$\hat{x}_m = D_m(\epsilon_m(\psi(x_m))).$$
Passing the input through the encoder and then a classification head gives the predicted class:
$$l_m = g_m(\epsilon_m(\psi(x_m))).$$
The loss in this stage has three parts: the reconstruction loss on the source dataset, the reconstruction loss on the target dataset, and the cross-entropy loss for action recognition on the source dataset.
$$L_{pretrain} = L_{recon}^{source} + L_{recon}^{target} + \lambda_{ce_m} L_{ce}^{source}.$$
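A minimal sketch of how this combined objective could be computed for one modality, assuming the model returns reconstructed patches and action logits as the `PretrainVisionTransformer` below does (all variable names here are illustrative):

```python
import torch
import torch.nn.functional as F

def pretrain_loss(model, x_src, mask_src, target_patches_src, labels_src,
                  x_tgt, mask_tgt, target_patches_tgt, lambda_ce=1.0):
    """L_pretrain = L_recon_source + L_recon_target + lambda * L_ce_source."""
    # source domain: reconstruction of the masked tubelets + action classification
    recon_src, logits_src = model(x_src, mask_src)
    loss_recon_src = F.mse_loss(recon_src, target_patches_src)
    loss_ce_src = F.cross_entropy(logits_src, labels_src)

    # target domain: reconstruction only, since no labels are available
    recon_tgt, _ = model(x_tgt, mask_tgt)
    loss_recon_tgt = F.mse_loss(recon_tgt, target_patches_tgt)

    return loss_recon_src + loss_recon_tgt + lambda_ce * loss_ce_src
```

Here `target_patches_*` are assumed to be the pixel values of the masked tubelets, matching the shape of the decoder output.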
The code for this part of the framework:
class PretrainVisionTransformer(nn.Module):
"""Vision Transformer with support for patch or hybrid CNN input stage"""
def __init__(
self,
img_size=224,
patch_size=16,
encoder_in_chans=3,
encoder_num_classes=0,
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
decoder_num_classes=1536, # decoder_num_classes=768,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=8,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
norm_layer=nn.LayerNorm,
init_values=0.0,
use_learnable_pos_emb=False,
use_checkpoint=False,
tubelet_size=2,
num_classes=0, # avoid the error from create_fn in timm
in_chans=0, # avoid the error from create_fn in timm
fc_drop_rate=0.5,
use_mean_pooling=True,
num_classes_action=204,
):
super().__init__()
self.encoder = PretrainVisionTransformerEncoder(
img_size=img_size,
patch_size=patch_size,
in_chans=encoder_in_chans,
num_classes=encoder_num_classes,
embed_dim=encoder_embed_dim,
depth=encoder_depth,
num_heads=encoder_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
init_values=init_values,
tubelet_size=tubelet_size,
use_checkpoint=use_checkpoint,
use_learnable_pos_emb=use_learnable_pos_emb,
)
self.decoder = PretrainVisionTransformerDecoder(
patch_size=patch_size,
out_chans=encoder_in_chans,
num_patches=self.encoder.patch_embed.num_patches,
num_classes=decoder_num_classes,
embed_dim=decoder_embed_dim,
depth=decoder_depth,
num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
init_values=init_values,
tubelet_size=tubelet_size,
use_checkpoint=use_checkpoint,
)
self.fc_dropout = (
nn.Dropout(p=fc_drop_rate) if fc_drop_rate > 0 else nn.Identity()
)
self.fc_norm = norm_layer(encoder_embed_dim) if use_mean_pooling else None
self.head_action = nn.Linear(encoder_embed_dim, num_classes_action)
self.encoder_to_decoder = nn.Linear(
encoder_embed_dim, decoder_embed_dim, bias=False
)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.pos_embed = get_sinusoid_encoding_table(
self.encoder.patch_embed.num_patches, decoder_embed_dim
)
trunc_normal_(self.mask_token, std=0.02)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token", "mask_token"}
def forward(self, x, mask):
_, _, T, _, _ = x.shape
x_vis = self.encoder(x, mask) # [B, N_encoded, C_e]
# classifier branch
x = self.fc_norm(x_vis.mean(1))
logits = self.head_action(self.fc_dropout(x))
# decoder branch
x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d]
B, N, C = x_vis.shape
# we don't unshuffle the correct visible token order,
# but shuffle the pos embedding accordingly.
expand_pos_embed = (
self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
)
pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
x_full = torch.cat(
[x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1
) # [B, N, C_d]
x = self.decoder(x_full, pos_emd_mask.shape[1]) # [B, N_mask, 3 * 16 * 16]
return x, logits
2.3 Multi-modal distillation
This part improves adaptation to the target domain by distilling the domain-adapted multi-modal features into the RGB encoder. Such multi-modal distillation mitigates the domain gap while reducing model complexity, and it further removes the overhead of processing and ensembling multiple data modalities during inference.
Assume the unlabeled target-domain inputs of the three modalities are $x_{RGB}^{target}$, $x_{flow}^{target}$, and $x_{pose}^{target}$, and that $\epsilon_{RGB}^{student}$, $\epsilon_{RGB}^{teacher}$, $\epsilon_{flow}^{teacher}$, and $\epsilon_{pose}^{teacher}$ all use the weights obtained from the training in the previous section. The multi-modal distillation is computed as follows (each teacher receives its own modality, as in the code below):
$$\begin{aligned}
\hat{f}_{RGB} &= M_{RGB}(\epsilon_{RGB}^{student}(\psi(x_{RGB}^{target})))\\
\hat{f}_{flow} &= M_{flow}(\epsilon_{RGB}^{student}(\psi(x_{RGB}^{target})))\\
\hat{f}_{pose} &= M_{pose}(\epsilon_{RGB}^{student}(\psi(x_{RGB}^{target})))\\
f_{RGB} &= \epsilon_{RGB}^{teacher}(\psi(x_{RGB}^{target}))\\
f_{flow} &= \epsilon_{flow}^{teacher}(\psi(x_{flow}^{target}))\\
f_{pose} &= \epsilon_{pose}^{teacher}(\psi(x_{pose}^{target})).
\end{aligned}$$
where $M_{RGB}$, $M_{flow}$, and $M_{pose}$ are multi-layer perceptrons; in the code each is implemented as a single linear layer.
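The released code wraps these projections in a `CrossModalTranslate` module (used as `self.cmt` below). Its definition is not reproduced here, but based on the description above it is presumably equivalent to three parallel linear heads, roughly as in this sketch (dimensions are assumptions):

```python
import torch.nn as nn

class CrossModalTranslate(nn.Module):
    """Project student RGB features into each teacher's feature space.

    Illustrative sketch only: the paper describes MLPs, while the released
    code reportedly uses a single linear layer per modality.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.to_rgb = nn.Linear(dim, dim)   # M_RGB
        self.to_flow = nn.Linear(dim, dim)  # M_flow
        self.to_pose = nn.Linear(dim, dim)  # M_pose

    def forward(self, x_rgb):
        # x_rgb: student RGB features, e.g. [B, dim]
        return self.to_rgb(x_rgb), self.to_flow(x_rgb), self.to_pose(x_rgb)
```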
The feature distillation loss for a single modality is defined as:
$$L_{fd_m} = \left\| sg[f_m] - \hat{f}_m \right\|_2^2,$$
where $sg[\cdot]$ denotes the stop-gradient operator (implemented with `.detach()` in the code below).
The total loss for this stage is then:
$$L_{distill} = L_{fd_{RGB}} + L_{fd_{flow}} + L_{fd_{pose}}.$$
The code for this multi-modal distillation is as follows:
class MMDistillTrainer(pl.LightningModule):
def __init__(self, cfg):
super(MMDistillTrainer, self).__init__()
self.cfg = cfg
# model
self.student_rgb = get_model(
cfg,
ckpt_pth=cfg.trainer.ckpt_path[0],
input_size=cfg.data_module.input_size[0],
patch_size=cfg.data_module.patch_size[0][0],
in_chans=cfg.trainer.in_chans[0],
)
self.teacher_flow = get_model(
cfg,
ckpt_pth=cfg.trainer.ckpt_path[1],
input_size=cfg.data_module.input_size[1],
patch_size=cfg.data_module.patch_size[1][0],
in_chans=cfg.trainer.in_chans[1],
)
self.teacher_pose = get_model(
cfg,
ckpt_pth=cfg.trainer.ckpt_path[2],
input_size=cfg.data_module.input_size[2],
patch_size=cfg.data_module.patch_size[2][0],
in_chans=cfg.trainer.in_chans[2],
)
self.cmt = CrossModalTranslate()
self.teacher_rgb = copy.deepcopy(self.student_rgb)
self.teacher_rgb.requires_grad_(False)
self.teacher_flow.requires_grad_(False)
self.teacher_pose.requires_grad_(False)
# loss
self.ce_loss = nn.CrossEntropyLoss()
self.mse_loss = nn.MSELoss()
self.train_top1_a = torchmetrics.Accuracy(
task="multiclass", num_classes=self.cfg.data_module.num_classes_action
)
# initialization
self.training_step_outputs = []
self.validation_step_outputs = []
self.test_step_outputs = []
def configure_optimizers(self):
self.scale_lr()
self.trainer.fit_loop.setup_data()
dataset = self.trainer.train_dataloader.dataset
self.niter_per_epoch = len(dataset) // self.total_batch_size
print("Number of training steps = %d" % self.niter_per_epoch)
print(
"Number of training examples per epoch = %d"
% (self.total_batch_size * self.niter_per_epoch)
)
optimizer, scheduler = get_optimizer_mmdistill(
self.cfg.trainer, [self.student_rgb, self.cmt], self.niter_per_epoch
)
return [optimizer], [scheduler]
def lr_scheduler_step(self, scheduler, metric):
cur_iter = self.trainer.global_step
next_lr = scheduler.get_epoch_values(cur_iter + 1)[0]
for param_group in self.trainer.optimizers[0].param_groups:
param_group["lr"] = next_lr
def _forward_loss_action(
self,
unlabel_frames_rgb_w,
unlabel_frames_flow_w,
unlabel_frames_pose_w,
mask=None,
):
# feature distillation
fr, _ = self.teacher_rgb(unlabel_frames_rgb_w, mask)
ff, _ = self.teacher_flow(unlabel_frames_flow_w, mask)
fp, _ = self.teacher_pose(unlabel_frames_pose_w, mask)
x_rgb, _ = self.student_rgb(unlabel_frames_rgb_w, mask)
trans_rgb, trans_flow, trans_pose = self.cmt(x_rgb)
trans_loss_rgb = self.mse_loss(trans_rgb, fr.detach())
trans_loss_flow = self.mse_loss(trans_flow, ff.detach())
trans_loss_pose = self.mse_loss(trans_pose, fp.detach())
return trans_loss_rgb, trans_loss_flow, trans_loss_pose
def training_step(self, batch, batch_idx):
input = batch
unlabel_frames_rgb_w = input["unlabel_frames_rgb"]
unlabel_frames_flow_w = input["unlabel_frames_flow"]
unlabel_frames_pose_w = input["unlabel_frames_pose"]
bool_masked_pos = input["mask"]
bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
trans_loss_rgb, trans_loss_flow, trans_loss_pose = self._forward_loss_action(
unlabel_frames_rgb_w,
unlabel_frames_flow_w,
unlabel_frames_pose_w,
bool_masked_pos,
)
loss = trans_loss_rgb + trans_loss_flow + trans_loss_pose
outputs = {
"train_loss": loss.item(),
"trans_loss_rgb": trans_loss_rgb.item(),
"trans_loss_flow": trans_loss_flow.item(),
"trans_loss_pose": trans_loss_pose.item(),
}
self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"])
self.log_dict(outputs)
return loss
def on_train_epoch_start(self):
# shuffle the unlabel data loader
unlabel_dir_to_img_frame = (
self.trainer.train_dataloader.dataset.unlabel_loader._dir_to_img_frame
)
unlabel_start_frame = (
self.trainer.train_dataloader.dataset.unlabel_loader._start_frame
)
lists = list(zip(unlabel_dir_to_img_frame, unlabel_start_frame))
random.shuffle(lists)
unlabel_dir_to_img_frame, unlabel_start_frame = zip(*lists)
self.trainer.train_dataloader.dataset.unlabel_loader._dir_to_img_frame = list(
unlabel_dir_to_img_frame
)
self.trainer.train_dataloader.dataset.unlabel_loader._start_frame = list(
unlabel_start_frame
)
def validation_step(self, batch, batch_idx):
input = batch[0]
frames_rgb = input["frames"]
action_idx = input["action_idx"]
bool_masked_pos = input["mask"]
bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
# convert labels for fewshot evaluation
_, action_idx = torch.unique(action_idx, return_inverse=True)
n_way = self.cfg.data_module.n_way
k_shot = self.cfg.data_module.k_shot
q_sample = self.cfg.data_module.q_sample
# RGB
frames_rgb, support_frames_rgb, query_frames_rgb = self.preprocess_frames(
frames=frames_rgb, n_way=n_way, k_shot=k_shot, q_sample=q_sample
)
# mask
support_mask = bool_masked_pos[: k_shot * n_way]
query_mask = bool_masked_pos[k_shot * n_way :]
action_idx = action_idx.view(n_way, (k_shot + q_sample))
support_action_label, query_action_label = (
action_idx[:, :k_shot].flatten(),
action_idx[:, k_shot:].flatten(),
)
pred_rgb, prob_rgb = self.LR(
self.student_rgb,
support=support_frames_rgb,
support_label=support_action_label,
query=query_frames_rgb,
support_mask=support_mask,
query_mask=query_mask,
)
acc = torchmetrics.Accuracy(task="multiclass", num_classes=5)
top1_action = acc(pred_rgb.cpu(), query_action_label.cpu())
outputs = {
"top1_action": top1_action.item(),
}
self.validation_step_outputs.append(outputs)
def on_validation_epoch_end(self):
top1_action = np.mean(
[output["top1_action"] for output in self.validation_step_outputs]
)
self.log("val_top1_action", top1_action, on_step=False)
self.validation_step_outputs.clear()
def test_step(self, batch, batch_idx):
input = batch[0]
frames_rgb = input["frames"]
action_idx = input["action_idx"]
bool_masked_pos = input["mask"]
bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)
# convert labels for fewshot evaluation
_, action_idx = torch.unique(action_idx, return_inverse=True)
n_way = self.cfg.data_module.n_way
k_shot = self.cfg.data_module.k_shot
q_sample = self.cfg.data_module.q_sample
# RGB
frames_rgb, support_frames_rgb, query_frames_rgb = self.preprocess_frames(
frames=frames_rgb, n_way=n_way, k_shot=k_shot, q_sample=q_sample
)
# mask
support_mask = bool_masked_pos[: k_shot * n_way]
query_masks = []
for _ in range(2):
query_mask = bool_masked_pos[k_shot * n_way :]
query_masks.append(query_mask)
# Shift by 1 in the batch dimension
bool_masked_pos = torch.cat(
(bool_masked_pos[1:], bool_masked_pos[:1]), dim=0
)
action_idx = action_idx.view(n_way, (k_shot + q_sample))
support_action_label, query_action_label = (
action_idx[:, :k_shot].flatten(),
action_idx[:, k_shot:].flatten(),
)
# # prediction with no mask
# pred_rgb, prob_rgb = self.LR(
# self.student_rgb,
# support=support_frames_rgb,
# support_label=support_action_label,
# query=query_frames_rgb,
# )
# prediction with mask and ensemble
pred_rgb_ensemble, prob_rgb_original = self.LR_ensemble(
self.teacher_rgb,
support=support_frames_rgb,
support_label=support_action_label,
query=query_frames_rgb,
support_mask=support_mask,
query_masks=query_masks[:2],
)
acc = torchmetrics.Accuracy(task="multiclass", num_classes=5)
# top1_action = acc(pred_rgb.cpu(), query_action_label.cpu())
top1_action_ensemble = acc(pred_rgb_ensemble.cpu(), query_action_label.cpu())
outputs = {
# "top1_action": top1_action.item(),
"top1_action_ensemble": top1_action_ensemble.item(),
}
self.test_step_outputs.append(outputs)
def on_test_epoch_end(self):
top1_action_ensemble = np.mean(
[output["top1_action_ensemble"] for output in self.test_step_outputs]
)
top1_action_ensemble_std = np.std(
[output["top1_action_ensemble"] for output in self.test_step_outputs]
)
top1_action_ensemble_std_error = top1_action_ensemble_std / np.sqrt(
len(self.test_step_outputs)
)
self.log("top1_action_ensemble", top1_action_ensemble, on_step=False)
self.log("top1_action_ensemble_std", top1_action_ensemble_std, on_step=False)
self.log(
"top1_action_ensemble_std_error",
top1_action_ensemble_std_error,
on_step=False,
)
def scale_lr(self):
self.total_batch_size = self.cfg.batch_size * len(self.cfg.devices)
self.cfg.trainer.lr = self.cfg.trainer.lr * self.total_batch_size / 256
self.cfg.trainer.min_lr = self.cfg.trainer.min_lr * self.total_batch_size / 256
self.cfg.trainer.warmup_lr = (
self.cfg.trainer.warmup_lr * self.total_batch_size / 256
)
print("LR = %.8f" % self.cfg.trainer.lr)
print("Batch size = %d" % self.total_batch_size)
def preprocess_frames(self, frames, n_way, k_shot, q_sample):
frames = rearrange(
frames, "(n m) c t h w -> n m c t h w", n=n_way, m=(k_shot + q_sample)
)
support_frames = rearrange(
frames[:, :k_shot],
"n m c t h w -> (n m) c t h w",
n=n_way,
m=k_shot,
)
query_frames = rearrange(
frames[:, k_shot:],
"n m c t h w -> (n m) c t h w",
n=n_way,
m=q_sample,
)
return frames, support_frames, query_frames
@torch.no_grad()
def LR(
self,
model,
support,
support_label,
query,
support_mask=None,
query_mask=None,
norm=False,
):
"""logistic regression classifier"""
support = model(support, support_mask)[0].detach()
query = model(query, query_mask)[0].detach()
if norm:
support = normalize(support)
query = normalize(query)
clf = sklearn.linear_model.LogisticRegression(
random_state=0,
solver="lbfgs",
max_iter=1000,
C=1,
multi_class="multinomial",
)
support_features_np = support.data.cpu().numpy()
support_label_np = support_label.data.cpu().numpy()
clf.fit(support_features_np, support_label_np)
query_features_np = query.data.cpu().numpy()
pred = clf.predict(query_features_np)
prob = clf.predict_proba(query_features_np)
pred = torch.from_numpy(pred).type_as(support)
prob = torch.from_numpy(prob).type_as(support)
return pred, prob
@torch.no_grad()
def LR_ensemble(
self,
model,
support,
support_label,
query,
support_mask=None,
query_masks=None,
norm=False,
):
"""logistic regression classifier"""
support = model(support, support_mask)[0].detach()
clf = sklearn.linear_model.LogisticRegression(
random_state=0,
solver="lbfgs",
max_iter=1000,
C=1,
multi_class="multinomial",
)
support_features_np = support.data.cpu().numpy()
support_label_np = support_label.data.cpu().numpy()
clf.fit(support_features_np, support_label_np)
probs = []
for query_mask in query_masks:
query_features = model(query, query_mask)[0].detach()
query_features_np = query_features.data.cpu().numpy()
prob = clf.predict_proba(query_features_np)
probs.append(prob)
probs = np.array(probs)
prob = np.mean(probs, axis=0)
pred = np.argmax(prob, axis=1)
pred = torch.from_numpy(pred).type_as(support)
prob = torch.from_numpy(prob).type_as(support)
return pred, prob
2.4 Few-shot adaptation and mask-ensemble inference
2.4.1 Few-shot adaptation
The data in the target-domain support set are passed through the embedding and tube masking described in Section 2.2 and fed into the distilled $\epsilon_{RGB}^{student}$. Each resulting feature along the batch dimension is then used as a training sample for the classifier $g'$, which is trained on these features. The classifier used in the code is a logistic regression classifier; a minimal sketch follows.
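A minimal sketch of this adaptation step, using scikit-learn's logistic regression on the frozen student features (it mirrors the `LR` helper above; the call `student(clips, mask)[0]` is an assumed stand-in for the student forward pass):

```python
import torch
from sklearn.linear_model import LogisticRegression

@torch.no_grad()
def fit_fewshot_classifier(student, support_clips, support_labels, support_mask):
    """Fit g' on support features from the frozen, distilled RGB student."""
    feats = student(support_clips, support_mask)[0]   # one feature per support clip
    clf = LogisticRegression(max_iter=1000, C=1.0)
    clf.fit(feats.cpu().numpy(), support_labels.cpu().numpy())
    return clf
```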
2.4.2 Mask-ensemble inference
After the few-shot training is complete, the query-set data are embedded and then tube-masked several times with different masks. Each masked version is fed into $\epsilon_{RGB}^{student}$, each resulting feature along the batch dimension is passed to the classifier $g'$ for prediction, and the predictions of all masked versions are averaged to obtain the final classification result, as sketched below.
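A minimal sketch of the ensemble step, following the `LR_ensemble` helper above (class probabilities are averaged across masks; `query_masks` is assumed to be a list of boolean tube masks):

```python
import numpy as np
import torch

@torch.no_grad()
def predict_with_mask_ensemble(student, clf, query_clips, query_masks):
    """Average g' probabilities over several differently masked query views."""
    probs = []
    for mask in query_masks:
        feats = student(query_clips, mask)[0]              # [Q, dim] query features
        probs.append(clf.predict_proba(feats.cpu().numpy()))
    mean_prob = np.mean(probs, axis=0)                     # [Q, n_way]
    return mean_prob.argmax(axis=1), mean_prob             # predictions, probabilities
```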
3. Innovations and limitations
3.1 Innovations
The core contribution of the paper is a multi-modal cross-domain few-shot learning paradigm for egocentric action recognition that overcomes the limitations of traditional methods through multi-modal collaboration and a lightweight inference mechanism. The method is the first to bring multi-modal distillation into the cross-domain few-shot setting: teacher models for modalities such as optical flow and hand pose heatmaps transfer their knowledge to an RGB student model, and self-supervised alignment on unlabeled target data mitigates the extreme visual gap between source and target domains (for example, industrial versus everyday settings). It innovatively turns modality complementarity into domain-adaptation capability: the sensitivity of optical flow to motion compensates for the fragility of RGB under lighting changes, while hand pose heatmaps provide precise local action cues, so that at inference time a single RGB modality inherits multi-modal robustness. In addition, the proposed ensemble-mask inference randomly masks spatio-temporal blocks of the input and ensembles several predictions, achieving a dynamic balance between computational efficiency and recognition accuracy and easing the deployment of Transformer models on resource-constrained devices.
3.2 Limitations
Although the paper makes progress on cross-domain adaptation and computational efficiency, several core limitations remain. First, the multi-modal distillation relies heavily on the modality completeness of the unlabeled target-domain data: if key modalities such as optical flow or hand pose are missing in practice (for example, a low-resolution camera from which optical flow cannot be extracted), performance degrades noticeably, and there is no mechanism for dynamic modality compensation. Second, although ensemble-mask inference improves efficiency, random masking can break the spatio-temporal continuity of an action; when key frames of fast hand manipulation or brief object interactions are masked, the ensemble struggles to recover the full action semantics, making recognition of temporally sensitive classes less stable. Third, the method assumes that the underlying action semantics of the source and target domains are well aligned; under extreme domain gaps (for example, industrial machine operation versus everyday kitchen activities), the shared pre-trained features may fail to capture domain-specific action patterns, so a large amount of unlabeled target data is needed for re-alignment, and the cost of collecting such data is often underestimated in practice. Finally, the framework does not fully account for the dynamic viewpoint shifts of egocentric video; when a wearable camera moves sharply, the fixed positional encoding and spatio-temporal patch partition may distort local action features and hurt cross-view generalization.
References
Masashi Hatano, Ryo Hachiuma, Ryo Fujii, and Hideo Saito. Multimodal Cross-Domain Few-Shot Learning for Egocentric Action Recognition.
Code: https://github.com/masashi-hatano/MM-CDFSL
Summary
In brief, the MM-CDFSL pipeline works as follows. In the meta-training stage, the model performs multi-modal pre-training and knowledge distillation: the RGB, optical flow, and hand pose modalities are trained independently on the VideoMAE architecture, jointly optimizing reconstruction and classification losses to learn discriminative features shared by the source and target domains. The pre-trained optical flow and hand pose teachers then distill their feature knowledge into the RGB student through a feature alignment loss (e.g., an L2 distance) on unlabeled target-domain data, with the teacher parameters frozen to keep the transfer stable.
After distillation, the model enters few-shot adaptation in the meta-testing stage: a few labeled samples from the target-domain support set (N-way K-shot) are used to fine-tune a lightweight classifier, while the student RGB encoder is kept fixed to preserve its cross-domain feature representation. To cope with the computational bottleneck of practical deployment, inference uses a dynamic tube-masking strategy that randomly masks spatio-temporal blocks to reduce the number of input tokens and ensembles the predictions of several masked versions to compensate for the information loss, significantly speeding up inference while maintaining accuracy. Despite its strong action recognition performance, the model still has the following shortcomings: reliance on multi-modal completeness in the target domain, masking that breaks spatio-temporal continuity, an idealized shared-feature assumption across domains, insufficient handling of dynamic viewpoint shifts, and underestimated target-domain data requirements.