SAMformer创新点

发布于:2024-07-03 ⋅ 阅读:(11) ⋅ 点赞:(0)

SAMformer 提供了一种新的方法来改进变换器(Transformer)在时间序列预测任务中的性能,特别是针对泛化能力和训练稳定性问题。具体来说,SAMformer结合了两个关键技术:锐度感知最小化(Sharpness-Aware Minimization, SAM)和通道注意力(Channel-Wise Attention)。以下是详细解释:

1. 锐度感知最小化(SAM)

背景

变换器模型在训练过程中,损失景观(Loss Landscape)通常比较陡峭,导致模型容易陷入局部最小值,影响泛化性能。为了缓解这个问题,Foret等人提出了锐度感知最小化(SAM)。

具体方法

SAM通过考虑参数空间中微小扰动的最大损失来优化模型参数。优化目标从标准的训练损失变为:

L train SAM ( ω ) = max ⁡ ∥ ϵ ∥ ≤ ρ L train ( ω + ϵ ) L_{\text{train}}^{\text{SAM}}(\omega) = \max_{\|\epsilon\| \leq \rho} L_{\text{train}}(\omega + \epsilon) LtrainSAM(ω)=ϵρmaxLtrain(ω+ϵ)

其中:

  • ω \omega ω是模型参数。
  • ϵ \epsilon ϵ是一个在范数约束内的小扰动。
  • ρ \rho ρ是一个超参数,控制扰动的范围。

通过这种方法,模型参数可以在更平滑的损失景观中找到更好的解,从而提高泛化能力。

2. 通道注意力(Channel-Wise Attention)

背景

标准的自注意力机制通常是对所有位置的特征进行加权和组合,而通道注意力机制则专注于特征之间的关系,适合用于多变量时间序列数据。

具体方法

通道注意力机制通过以下方式应用于输入序列:

A ( X ) = softmax ( X W Q W K ⊤ X ⊤ d m ) \mathbf{A}(\mathbf{X}) = \text{softmax} \left( \frac{\mathbf{X} \mathbf{W}_Q \mathbf{W}_K^\top \mathbf{X}^\top}{\sqrt{d_{m}}} \right) A(X)=softmax(dm XWQWKX)

其中:

  • X \mathbf{X} X是输入序列。
  • W Q \mathbf{W}_Q WQ W K \mathbf{W}_K WK是查询和键的权重矩阵。
  • d m d_m dm是注意力机制的维度。

通道注意力机制的特点是它关注输入特征之间的关系,而不是输入序列的位置关系。这使得模型能够更好地捕捉多变量时间序列中的特征依赖性。

SAMformer的具体实现

SAMformer结合了上述两种技术,通过以下方式进行实现:

  1. 输入处理
    对于输入序列 X \mathbf{X} X,首先通过通道注意力机制计算注意力矩阵,并通过残差连接将原始输入与注意力输出相加。

  2. 锐度感知最小化
    在训练过程中,使用SAM优化目标函数,通过考虑参数扰动来优化模型参数,使得模型在更平滑的损失景观中找到全局最优解。

具体步骤

  1. 初始化模型参数 ω \omega ω

  2. 计算通道注意力矩阵 A ( X ) \mathbf{A}(\mathbf{X}) A(X)并结合残差连接得到新的输入表示:
    f ( X ) = [ X + A ( X ) X W V W O ] W f(\mathbf{X}) = [\mathbf{X} + \mathbf{A}(\mathbf{X}) \mathbf{X} \mathbf{W}_V \mathbf{W}_O] \mathbf{W} f(X)=[X+A(X)XWVWO]W

  3. 应用SAM进行参数更新
    在每次参数更新时,计算在小扰动范围内的最大损失,并使用该损失进行参数更新:
    ω = ω − η ∇ L train SAM ( ω ) \omega = \omega - \eta \nabla L_{\text{train}}^{\text{SAM}}(\omega) ω=ωηLtrainSAM(ω)

实验结果

实验表明,SAMformer在常见的多变量时间序列预测任务中,比标准变换器模型和其他基线模型具有更好的性能。特别是在小规模数据集和高噪声数据集上,SAMformer展示了更好的泛化能力和训练稳定性。

结论

SAMformer通过结合锐度感知最小化和通道注意力机制,解决了变换器模型在时间序列预测任务中的泛化能力差和训练不稳定问题。这种方法不仅提高了模型的预测性能,还使得模型在各种复杂的实际应用场景中更加稳健和可靠。