位置编码
elif self.dec_way == "pmf":
if self.channel_id:
# m,d//2 -> 1,m,d//2 -> c,m,d//2
# c,d//2 -> c,1,d//2 -> c,m,d//2
# c,m,d -> cm,1,d -> bcm, 1, d
pos_emb = torch.cat([
self.pos_emb.unsqueeze(0).repeat(self.enc_in, 1, 1),
self.channel_emb.unsqueeze(1).repeat(1, self.seg_num_y, 1)
], dim=-1).view(-1, 1, self.d_model).repeat(batch_size,1,1)
┌─────────────┐ ┌─────────────┐
│ pos_emb │ │ channel_emb │
│ [2, 256] │ │ [7, 256] │
└──────┬──────┘ └──────┬──────┘
│ │
▼ ▼
┌─────────────┐ ┌─────────────┐
│ unsqueeze(0)│ │unsqueeze(1) │
│ [1,2,256] │ │ [7,1,256] │
└──────┬──────┘ └──────┬──────┘
│ │
▼ ▼
┌─────────────┐ ┌─────────────┐
│ repeat │ │ repeat │
│ [7,2,256] │ │ [7,2,256] │
└──────┬──────┘ └──────┬──────┘
│ │
└────────────┬───────────────┘
│
▼
┌───────────────┐
│ concat(dim=-1)│
│ [7,2,512] │
└───────┬───────┘
│
▼
┌───────────────┐
│ view(-1,1,512)│
│ [14,1,512] │
└───────┬───────┘
│
▼
┌───────────────┐
│ repeat │
│ [224,1,512] │
└───────┬───────┘
│
▼
最终组合嵌入
理解时间序列数据的训练集、序列长度和批次大小
我又有了一个新的问题,训练集大小是 593 个样本,怎么 batchsize=16,seqlen=60
用自己的话,593 个样本,随机选择了 16 个,这16 个样本的时间步并不一定是连续的,但是一个 batch 内部封装的 60 个时间步一定是连续的
shuffle=True