GPT - 因果掩码(Causal Mask)

发布于:2025-04-12 ⋅ 阅读:(40) ⋅ 点赞:(0)

本节代码定义了一个函数 causal_mask,用于生成因果掩码(Causal Mask)。因果掩码通常用于自注意力机制中,以确保模型在解码时只能看到当前及之前的位置,而不能看到未来的信息。这种掩码在自然语言处理任务(如语言生成)中非常重要,因为它模拟了人类阅读或写作时的顺序性。

一、因果掩码(Causal Mask)代码实现

def causal_mask(x):
    mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0
    return mask
1. 输入参数
  • x:输入张量,通常是一个序列,形状为 (seq_len, d_model)(batch_size, seq_len, d_model)。这里的 seq_len 是序列的长度。

2. 生成掩码
mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0
  • torch.ones(x.shape[0], x.shape[0]):生成一个形状为 (seq_len, seq_len) 的全1矩阵。

  • torch.triu(..., diagonal=1):取该矩阵的上三角部分(包括对角线),其余部分设置为0。diagonal=1 表示从对角线的下一个位置开始取上三角部分。

  • == 0:将上三角部分(包括对角线)的值设置为 False,其余部分设置为 True。这样生成的掩码矩阵中,True 表示需要保留的注意力位置,False 表示需要被忽略的注意力位置。

3. 返回值
  • mask:生成的因果掩码,形状为 (seq_len, seq_len),是一个布尔张量。

示例

假设输入张量 x 的形状为 (5, d_model),即序列长度为5。那么:

x = torch.randn(5, d_model)  # 示例输入
mask = causal_mask(x)
print(mask)

输出的掩码矩阵 mask 将是:

tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])

作用

在自注意力机制中,因果掩码用于确保模型在计算注意力分数时,只能看到当前及之前的位置,而不能看到未来的信息。具体来说:

  • True:表示可以计算注意力分数。

  • False:表示需要被忽略,注意力分数会被设置为一个非常小的值(如 -1e9),从而在 softmax 归一化后,其权重趋近于0。

二、因果掩码如何使用?

1. 因果掩码的生成

因果掩码的生成函数如下:

def causal_mask(x):
    mask = torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1) == 0
    return mask
  • 输入x 是一个张量,通常是一个序列的嵌入表示,形状为 (seq_len, d_model)(batch_size, seq_len, d_model)

  • 输出:生成一个布尔张量 mask,形状为 (seq_len, seq_len),其中上三角部分(包括对角线)为 True,其余部分为 False

2. 因果掩码的应用

因果掩码在 Poetry 数据集类中被应用,具体如下:

class Poetry(Dataset):
    def __init__(self, poetries, tokenizer: Tokenizer):
        self.poetries = poetries
        self.tokenizer = tokenizer

        self.pad_id = self.tokenizer.vocab["[PAD]"]
        self.bos_id = self.tokenizer.vocab["[BOS]"]
        self.eos_id = self.tokenizer.vocab["[EOS]"]

    def __len__(self):
        return len(self.poetries)
    
    def __getitem__(self, idx):
        poetry = self.poetries[idx]
        poetry_ids = self.tokenizer.encode(poetry)
        input_ids = torch.tensor([self.bos_id] + poetry_ids)
        input_msk = causal_mask(input_ids)
        label_ids = torch.tensor(poetry_ids + [self.eos_id])
        return {
            "input_ids": input_ids,
            "input_msk": input_msk,
            "label_ids": label_ids
        }
  • __getitem__ 方法

    • 对于每首诗 poetry,将其编码为 poetry_ids

    • 在输入序列的开头添加 [BOS](开始标记符),生成 input_ids

    • 使用 causal_mask 函数生成因果掩码 input_msk

    • 在标签序列的末尾添加 [EOS](结束标记符),生成 label_ids

3. 因果掩码的传递

在训练过程中,因果掩码 input_msk 会被传递给模型的自注意力层。具体如下:

for epoch in range(epochs):
    for batch in tqdm(trainloader, desc="Training"):
        batch_input_ids = batch["input_ids"]
        batch_input_msk = batch["input_msk"]
        batch_label_ids = batch["label_ids"]

        output = model(batch_input_ids, batch_input_msk)
        loss = loss_fn(output.view(-1, len(vocab)), batch_label_ids.view(-1))
        loss.backward()
        optim.step()
        optim.zero_grad()
  • model(batch_input_ids, batch_input_msk)

    • batch_input_ids 是输入序列的嵌入表示。

    • batch_input_msk 是对应的因果掩码。

    • 模型在计算自注意力时,会使用 batch_input_msk 来确保解码器只能看到当前及之前的位置。

4. 因果掩码的作用

MultiHeadAttention 类中,因果掩码被应用到注意力分数矩阵中:

if attn_mask is not None:
    attn_mask = attn_mask.unsqueeze(1)
    atten_scores = atten_scores.masked_fill(attn_mask == 0, -1e9)
  • attn_mask.unsqueeze(1)

    • 将掩码的形状从 (batch_size, seq_len, seq_len) 扩展为 (batch_size, 1, seq_len, seq_len)

  • masked_fill

    • 将掩码中为 False 的位置的注意力分数设置为 -1e9,确保这些位置的注意力权重趋近于0。

5. 生成诗歌时的因果掩码

在生成诗歌时,因果掩码同样被应用:

def generate_poetry(method="greedy", top_k=5):
    model.eval()
    with torch.no_grad():
        input_ids = torch.tensor(vocab["[BOS]"]).view(1, -1)

        while input_ids.shape[1] < seq_len:
            output = model(input_ids, None)
            probabilities = torch.softmax(output[:, -1, :], dim=-1)
            
            if method == "greedy":
                next_token_id = torch.argmax(probabilities, dim=-1)
            elif method == "top_k":
                top_k_probs, top_k_indices = torch.topk(probabilities[0], top_k)
                next_token_id = top_k_indices[torch.multinomial(top_k_probs, 1)]

            if next_token_id == vocab["[EOS]"]:
                break

            input_ids = torch.cat([input_ids, next_token_id.view(1, 1)], dim=1)
    return input_ids.squeeze()
  • model(input_ids, None)

    • 在生成诗歌时,输入序列 input_ids 会逐渐增长,但因果掩码是隐含的,因为模型的自注意力层会自动处理序列的顺序性。

    • 生成过程中,模型只能看到当前及之前的位置,这与训练时使用因果掩码的目的相同。



 


网站公告

今日签到

点亮在社区的每一天
去签到