【SegRNN 源码理解】图示理解 forward的过程

发布于:2025-03-06 ⋅ 阅读:(19) ⋅ 点赞:(0)

 

输入: x [16, 60, 7]
(16个批次,每个60个时间步,每步7个特征)
            │
            ▼
┌──────────────────────────────┐
│      RevIN 标准化 + 维度置换     │
│   x = revinLayer(x, 'norm')   │
│      .permute(0, 2, 1)       │
└──────────────┬───────────────┘
               │
               ▼
          x [16, 7, 60]
(16个批次,7个特征,每个特征60个时间步)
            │
            ▼
┌──────────────────────────────┐
│         重塑为分段格式          │
│ x.reshape(-1, seg_num_x, seg_len) │
└──────────────┬───────────────┘
               │
               ▼
          x [112, 5, 12]
(112个序列=16批次×7特征,每个分5段,每段12步)
            │
            ▼
┌──────────────────────────────┐
│          段值嵌入             │
│     x = valueEmbedding(x)     │
│  (Linear: 12 → 512 + ReLU)    │
└──────────────┬───────────────┘
               │
               ▼
          x [112, 5, 512]
(112个序列,5个段,每段表示为512维向量)
            │
            ▼
┌──────────────────────────────┐
│           GRU 编码            │
│      _, hn = self.rnn(x)      │
└──────────────┬───────────────┘
               │
               ▼
          hn [1, 112, 512]
(1层GRU,112个序列的最终隐藏状态)
            │
            ▼
┌────────────────┬─────────────┐
│    RMF 解码     │    PMF 解码   │
└────────┬───────┴──────┬──────┘
         │              │
┌────────▼───────┐  ┌───▼──────────┐
│  循环多步预测    │  │  并行多步预测   │
│(逐段自回归预测)  │  │(一次性预测所有段)│
└────────┬───────┘  └───┬──────────┘
         │              │
         │       ┌──────▼─────────┐
         │       │  位置和通道嵌入   │
         │       │   组合成条件     │
         │       └──────┬─────────┘
         │              │
         │       ┌──────▼─────────┐
         │       │ pos_emb [224, 1, 512] │
         │       │ (224=16×7×2)   │
         │       └──────┬─────────┘
         │              │
         │       ┌──────▼─────────┐
         │       │   条件GRU解码    │
         │       │ _, hy = rnn(pos_emb, hn) │
         │       └──────┬─────────┘
         │              │
         │       ┌──────▼─────────┐
         │       │ hy [1, 224, 512] │
         │       └──────┬─────────┘
         │              │
┌────────▼───────┐ ┌────▼──────────┐
│ 预测 + 堆叠各段  │ │     预测       │
└────────┬───────┘ │ y = predict(hy) │
         │         └────┬───────────┘
         │              │
         └──────┬───────┘
                │
                ▼
           y [16, 7, 24]
 (预测结果: 16个批次,7个特征,预测24个时间步)
                │
                ▼
┌───────────────────────────────┐
│      维度置换 + RevIN反标准化    │
│   y = revinLayer(y.permute(0, 2, 1), 'denorm')   │
└───────────────┬───────────────┘
                │
                ▼
       最终输出 y [16, 24, 7]
(16个批次,预测未来24个时间步,每步7个特征)

SegRNN 前向传播的关键处理阶段解释

1. 数据预处理和视角转换

  • 输入形状[16, 60, 7] (批次, 时间步, 特征)
  • RevIN标准化:对每个特征序列进行可逆实例标准化,消除分布偏移
  • 维度置换:将视角从"时间优先"转为"特征优先" → [16, 7, 60]
  • 作用:为每个特征序列建立独立处理路径

2. 序列分段和嵌入

  • 重塑和分段[16, 7, 60] → [112, 5, 12]
    • 112 = 16(批次) × 7(特征)
    • 每个序列分为5段,每段12个时间步
  • 线性嵌入:将每个12维的段映射到512维空间
  • 作用:引入层次化表示,捕获不同时间尺度的模式

3. GRU序列编码

  • 序列编码:GRU处理5个段之间的时间依赖关系
  • 输出隐藏状态hn [1, 112, 512]
  • 作用:将历史信息压缩到一个固定长度向量

4. 解码阶段(PMF模式)

  • 位置和通道嵌入组合
    • 位置嵌入:[2, 256] → 预测的2个段
    • 通道嵌入:[7, 256] → 7个不同特征
    • 组合后:[224, 1, 512] (224 = 16×7×2)
  • 条件GRU解码
    • 输入:位置嵌入序列
    • 条件:复制并重塑的编码器隐藏状态
    • 输出:hy [1, 224, 512]
  • 预测:线性层 [224, 512] → [224, 12] → 重塑为 [16, 7, 24]
  • 作用:基于历史信息和位置/通道条件,一次生成所有预测段

5. 输出处理

  • 维度置换[16, 7, 24] → [16, 24, 7]
  • RevIN反标准化:将标准化的数据转换回原始分布
  • 最终输出[16, 24, 7] (16个批次,预测24个时间步,每步7个特征)

模型设计亮点

  1. 层次化架构:通过分段+RNN的两级架构,有效处理不同时间尺度
  2. 特征独立处理:每个特征有独立的处理路径,减少干扰
  3. 可逆标准化:RevIN处理分布偏移,保留原始分布特性
  4. 条件生成:位置嵌入+通道嵌入提供细粒度控制的条件生成

这种设计使SegRNN能高效处理长序列预测问题,特别是在多变量时间序列中,不同特征具有不同时间模式的情况。


    网站公告

    今日签到

    点亮在社区的每一天
    去签到