SamOut 推理空间不变模型解析

发布于:2024-12-20 ⋅ 阅读:(186) ⋅ 点赞:(0)

项目地址

SamOutV2 0.18B模型

  • 采取 em参数共享在参数量减半的情况下将维度从1024 拉升到了1536
  • sft 单论对话 loss 保持1.8
  • 如果未来匹配state 推理代码性能不变的同时推理任意长度使用资源空间保持不变
import torch


class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.head_num = heads

        self.hidden = hidden_dim

    def forward(self, input_data, state=None):
        b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size

        out = self.head0(input_data)

        out1 = self.head1(input_data)

        out2 = self.head2(input_data)

        out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        if state is None:
            out = torch.cummax((out + out1) / h ** 0.5, 2)[0]

        else:
            out = torch.cummax(torch.concat([state, (out + out1)/ h ** 0.5], 2) , 2)[0]
        state1 = out[:, :, -1:]

        out = out.permute([0, 2, 1, 3])
        out1 = out1.permute([0, 2, 1, 3])
        if state is None:
            out = out.reshape([b, s, -1])
            out1 = out1.reshape([b, s, -1])
        else:

            out = out[:, -1:].reshape([b, 1, -1])
            out1 = out1[:, -1:].reshape([b, 1, -1])

        out = (out + out2) * out + out1

        return out, state1


class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()

        self.ffn1 = torch.nn.Linear(hidden_size, hidden_size // 2)
        self.ffn2 = torch.nn.Linear(hidden_size // 2, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, hidden_size // 2)

        self.relu = torch.nn.ReLU()
        self.gr = torch.nn.Dropout(0.1)

    def forward(self, x):
        x1 = self.ffn1(x)
        x2 = self.relu(self.gate(x))
        xx = x1 * x2
        x = self.gr(self.ffn2(xx))
        return x


class DecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()

        self.self_attention = MaxState(hidden_size, num_heads)

        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

        self.alpha = torch.nn.Parameter(torch.tensor(0.5))

    def forward(self, x, state=None, ):
        x1, state = self.self_attention(x, state)
        x = self.layer_norm(self.alpha * self.ffn(x1) + (1 - self.alpha) * x)

        return x, state


class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)

        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = FeedForward(hidden_size)

    def state_forward(self, state, x):
        if state is None:
            state = [None] * len(self.decoder_layers)
        i = 0
        for ii, decoder_layer in enumerate(self.decoder_layers):
            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x
            i += 1
        return x, state

    def forward(self, x, state=None):
        x = self.em(x)

        x, state = self.state_forward(state, x)

        em = self.head(self.em.weight) / x.shape[-1]

        return x @ em.permute([1, 0]), state


if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4)
    net(torch.randint(0, 200, [2, 8 * 13]))

解析该代码为什么推理空间不变,需要深入理解其设计和实现细节。首先,让我们逐步分析这段代码中的各个组件及其工作原理。

MaxState 类

MaxState 类定义了一个多头机制(multi-head mechanism),它可能是为了模仿Transformer架构中的多头注意力机制。这个类中定义了三个线性变换层 (self.head0, self.head1, 和 self.head2) 来处理输入数据,并通过累积最大值操作 (torch.cummax) 来更新状态。这种设计确保了在每次前向传播时,模型能够保留到目前为止遇到的最大值,从而保持了一种记忆效应。这有助于维持推理过程中的上下文信息而不改变空间维度。

class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads):
        ...
    
    def forward(self, input_data, state=None):
        ...
        if state is None:
            out = torch.cummax((out + out1) / h ** 0.5, 2)[0]
        else:
            out = torch.cummax(torch.concat([state, (out + out1)/ h ** 0.5], 2) , 2)[0]
        ...

FeedForward 类

FeedForward 类实现了简单的前馈神经网络(FFN),它包括两个线性层以及门控机制来控制信息流。此模块负责执行空间变换,但它的输出形状与输入相同,因此不会改变特征图的空间尺寸。

class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        ...
    
    def forward(self, x):
        ...
        return x

DecoderLayer 类

DecoderLayer 类结合了自注意力机制(由 MaxState 实现)和前馈网络(FeedForward)。这里引入了一个可学习参数 alpha 来调整来自这两个子模块的信息混合比例。重要的是,层归一化(LayerNorm)被应用于最终输出之前,以稳定训练并帮助梯度流动。此外,由于 MaxState 的设计保证了状态的持续更新而没有改变序列长度或隐藏维度,所以整个解码器层也不会改变推理空间。

class DecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        ...
    
    def forward(self, x, state=None, ):
        ...
        return x, state

SamOut 类

SamOut 类是整个模型的核心,它包含了嵌入层、多个解码器层组成的列表(ModuleList),以及一个最终的前馈网络用于生成预测结果。关键点在于:

  • Embedding 层:将词汇索引转换为密集向量表示。
  • 解码器层堆叠:通过循环调用每个解码器层来进行多次迭代处理。每层都接收当前状态作为输入,并返回更新后的状态。
  • 最终投影:使用嵌入矩阵转置乘以前一层输出的方式计算logits。这种方式确保了输出维度与词汇表大小相匹配,同时保持了输入序列的时间步数不变。
class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        ...
    
    def state_forward(self, state, x):
        ...
        return x, state
    
    def forward(self, x, state=None):
        ...
        return x @ em.permute([1, 0]), state

推理空间不变的原因

综上所述,SamOut 模型之所以能够在推理过程中保持空间不变,是因为从输入到输出的所有操作都被精心设计以保持时间序列的长度一致。无论是通过自定义的 MaxState 进行状态更新还是通过标准的 FeedForward 进行非线性变换,这些操作都不会减少或增加时间步的数量。即使在应用了多层解码器之后,序列的长度依然保持不变,只是特征表示得到了增强或转换。最后,在生成预测时,模型采用了基于嵌入矩阵的操作,这也保证了输出的概率分布对应于原始输入序列的每一个位置,因此推理空间在整个过程中是稳定的。

这种特性对于诸如语言模型等应用场景非常重要,因为它们通常需要根据给定的历史信息逐个生成新词,而不能随意改变文本序列的结构。通过这种方式,模型可以在不破坏原有顺序的情况下有效地捕捉长期依赖关系,并且在生成过程中逐步构建合理的句子或段落。