【领域泛化】论文介绍《Domain generalization via multidomain discriminant analysis》
论文地址
http://proceedings.mlr.press/v115/hu20a/hu20a.pdf
摘要
在领域泛化(DG) 中,有一个很常见的假设,就是分布偏移只存在于边缘分布 P ( X ) P(X) P(X),即只发生先验偏移(Prior Shift),没有发生概念偏移(Concept Shift),条件分布(后验分布) P ( Y ∣ X ) P(Y|X) P(Y∣X)不同域是不变的,这就是为什么大多数DIR要对齐表示空间的边缘分布 P ( X ) P(X) P(X)。
但从因果分析的角度来看,只有当 X X X是 Y Y Y的原因的时候,这种对齐才是有效的,但对于很多任务,特别是分类任务, Y Y Y通常是 X X X的原因,本文正是这类研究对齐类条件 P ( X ∣ Y ) P(X|Y) P(X∣Y)方法的典型,提出了 P ( Y ) P(Y) P(Y)随着 P ( X ∣ Y ) P(X|Y) P(X∣Y)一起变化情况下的领域泛化方法。
本文提出一种多域判别分析(MDA),旨在最小化同类中不同域的的分歧,最大化类之间的可分离性,以及整体所有类的紧致性,来尝试进行更好的领域泛化。
核心思想
动机
如果 Y Y Y是 X X X的原因,那么 P ( Y ) P(Y) P(Y)边缘分布和 P ( X ∣ Y ) P(X|Y) P(X∣Y)条件分布会彼此“独立”,因为 P ( X ∣ Y ) P(X|Y) P(X∣Y)不包含 P ( Y ) P(Y) P(Y)的信息。但在跨域的情况下, P ( X ∣ Y ) P(X|Y) P(X∣Y)和 P ( Y ) P(Y) P(Y)会耦合地发生变化。
本文提出一种方法,适用于 P ( X ∣ Y ) P(X|Y) P(X∣Y)和 P ( Y ) P(Y) P(Y)跨域变化的领域泛化任务,该方法关注类的可分离性,不强制对齐表征空间的边缘分布(DIRs类方法),这样可以放松 P ( Y ) P(Y) P(Y)稳定的约束(另一篇2018年论文的工作,本文在它的基础上放松了约束,更具有适用性)。
本文主要做出了两个贡献:
- 提出一种新的度量方式——平均类差异,将该度量方式和其他三种度量统一到一个目标中进行学习。
- 本文推导了在基于核的域不变特征学习变换方法中超额风险和泛化误差的界限,该贡献是在领域泛化中对超额风险提供理论支持的最早研究之一。(该部分并不是重点,日后有时间补上)
正则化约束
平均域差异
我们首先考虑最小化在所有域的每个类的类条件分布 P s ( X ∣ Y = j ) P^{s}(X|Y=j) Ps(X∣Y=j)的差异,对于m个域,c个类别,核均值嵌入 μ j s \mu_{j}^{s} μjs表示 P s ( X ∣ Y = j ) P^{s}(X|Y=j) Ps(X∣Y=j), H H H表示再生希尔伯特空间(RKHS),平均域差异 L a d d L_{add} Ladd的定义如下:
L a d d = 1 C 2 m ∑ j = 1 c ∑ 1 ≤ s ≤ s ′ ≤ m ∣ ∣ μ j s − μ j s ′ ∣ ∣ H 2 L_{add} = \frac{1}{C_{2}^{m}}\sum_{j=1}^{c}\sum_{1 \leq s \leq s' \leq m} || \mu_{j}^{s} - \mu_{j}^{s'}||_{H}^2 Ladd=C2m1j=1∑c1≤s≤s′≤m∑∣∣μjs−μjs′∣∣H2
其中, ∣ ∣ μ j s − μ j s ′ ∣ ∣ H || \mu_{j}^{s} - \mu_{j}^{s'}||_{H} ∣∣μjs−μjs′∣∣H是MMD损失,这里累和的是该损失的二范式。可知, L a d d = 0 L_{add} = 0 Ladd=0仅在 P j 1 = P j 2 = P j 3 . . . = P j m P^{1}_{j} = P^{2}_{j} = P^{3}_{j} ... = P^{m}_{j} Pj1=Pj2=Pj3...=Pjm时成立,所以能测量多个域的相同类的类条件分布差异。
平均类差异
最小化平均域差异 L a d d L_{add} Ladd会使同一类的不同域数据的类条件分布均值在再生希尔伯特空间(RKHS) 尽可能接近,然而,不同类的类条件分布均值也有可能尽可能接近(在《Respecting domain relations Hypothesis invariance for domain generalization》提及过不同类的表征在DIRs后混淆在一起),这是这类DG方法性能下降的重要原因,因此,本文提出了平均类差异 L a c d L_{acd} Lacd来区别开不同的类的表征。
L a c d = 1 C 2 c ∑ 1 ≤ j ≤ j ′ ≤ c ∣ ∣ ∑ s = 1 m P ( S = s ∣ Y = j ) μ j s − ∑ s = 1 m P ( S = s ∣ Y = j ) μ j s ′ ∣ ∣ H 2 L_{acd} = \frac{1}{C_{2}^{c}}\sum_{1 \leq j \leq j' \leq c} || \sum_{s = 1 }^{m}P(S=s|Y=j) \mu_{j}^{s} - \sum_{s = 1 }^{m}P(S=s|Y=j)\mu_{j}^{s'}||_{H}^2 Lacd=C2c11≤j≤j′≤c∑∣∣s=1∑mP(S=s∣Y=j)μjs−s=1∑mP(S=s∣Y=j)μjs′∣∣H2
合并实例级信息
某些微妙的信息,例如分布的紧致性,在上述两个损失中无法捕捉到,这里提出额外两个措施多域类间散射和多域类内散射来解决。
多域类间散射
设样本为来自m个领域的n个实例,每个实例由c个类组成,多域类间散射为:
L m b s = 1 n ∑ j = 1 c n j ∣ ∣ u j − u ˉ j ∣ ∣ H 2 L_{mbs} = \frac{1}{n}\sum_{j = 1 }^{c}n_{j} ||u_{j} - \bar u_{j}||_{H}^2 Lmbs=n1j=1∑cnj∣∣uj−uˉj∣∣H2
L m b s L_{mbs} Lmbs和 L a c d L_{acd} Lacd都衡量了不同阶层分布之间的差异, L m b s L_{mbs} Lmbs就如同一个简单的池化方案,将同一类的所有实例聚集在一起。
多域类内散射
设样本为来自m个领域的n个实例,每个实例由c个类组成,多域类内散射为:
L m w s = 1 n ∑ j = 1 c ∑ s = 1 m ∑ i = 1 n j s ∣ ∣ ϕ ( x i ∈ j s ) − u j ∣ ∣ H 2 L_{mws} = \frac{1}{n}\sum_{j = 1 }^{c}\sum_{s= 1 }^{m}\sum_{i= 1 }^{n_{j}^{s}} ||\phi(x_{i\in j}^{s}) - u_{j}||_{H}^2 Lmws=n1j=1∑cs=1∑mi=1∑njs∣∣ϕ(xi∈js)−uj∣∣H2
其中, x i ∈ j s x_{i\in j}^{s} xi∈js表示类 j j j在领域 s s s中的第 i i i个实例的特征向量, n j s n_{j}^{s} njs表示领域 s s s中类 j j j的实例总数。
可知,多域类内散射度量的是每个实例的规范特征映射与其所属类在RKHS空间H的均值表示之间的距离之和,与平均域差异 L a d d L_{add} Ladd的区别在于考虑了每个实例的信息。
总结
每一种度量正则化方式都是必要的,缺少将导致次优解。
核方法特征提取
这部分对做深度学习的不是很重要,日后有机会补上
经验估计
在实际场景中,只能从m个源域的有限数量的实例来估计 μ j s \mu_{j}^{s} μjs和 u j u_{j} uj,设 x i ∈ j s x_{i\in j}^{s} xi∈js表示类 j j j在领域 s s s中的第 i i i个实例的特征向量, n j s n_{j}^{s} njs表示领域 s s s中类 j j j的实例总数,所以 μ j s \mu_{j}^{s} μjs可估计为:
μ j s ^ = 1 n j s ∑ i = 1 n j s ϕ ( x i ∈ j s ) \hat{\mu_{j}^{s}} = \frac{1}{n_{j}^{s}}\sum_{i= 1 }^{n_{j}^{s}}{\phi(x_{i\in j}^{s})} μjs^=njs1i=1∑njsϕ(xi∈js)
对于 u j u_{j} uj,需要 P ( S = s ∣ Y = j ) P(S=s|Y=j) P(S=s∣Y=j),该概率可以使用贝叶斯规则进行估计:
P ( S = s ∣ Y = j ) = P ( Y = j ∣ S = s ) ∗ P ( S = s ) P ( Y = j ) P(S=s|Y=j)=\frac{P(Y=j|S=s)*P(S=s)}{P(Y=j)} P(S=s∣Y=j)=P(Y=j)P(Y=j∣S=s)∗P(S=s)
我们可以假设采样所有源域的概率是相等的( ∀ s , P ( S = s ) = 1 m \forall s,P(S=s) = \frac{1}{m} ∀s,P(S=s)=m1),则上式简化为:
P ( S = s ∣ Y = j ) = n j s / n j s ∑ s ′ = 1 m n j s ′ / n j s ′ P(S=s|Y=j)=\frac{n_{j}^{s} / n_{j}^{s}}{\sum_{s'=1}^{m} n_{j}^{s'} / n_{j}^{s'}} P(S=s∣Y=j)=∑s′=1mnjs′/njs′njs/njs
所以, u j u_{j} uj就可以以下式推导:
u j ^ = ∑ s = 1 m n j s / n j s ∑ s ′ = 1 m n j s ′ / n j s ′ μ j s ^ \hat{u_{j}}=\sum_{s=1}^{m}\frac{n_{j}^{s} / n_{j}^{s}}{\sum_{s'=1}^{m} n_{j}^{s'} / n_{j}^{s'}}\hat{\mu_{j}^{s}} uj^=s=1∑m∑s′=1mnjs′/njs′njs/njsμjs^
核方法优化目标
这部分对做深度学习的不是很重要,日后有机会补上
核方法分析超额风险和泛化误差
这部分对做深度学习的不是很重要,日后有机会补上
代码实现
未完成
论文引用
Hu S, Zhang K, Chen Z, et al. Domain generalization via multidomain discriminant analysis[C]//Uncertainty in Artificial Intelligence. PMLR, 2020: 292-302.