AF3 PairStack类源码解读

发布于:2025-02-10 ⋅ 阅读:(52) ⋅ 点赞:(0)

PairStack 是 AlphaFold 的核心模块之一,用于对残基对(residue-residue pair)的特征张量 z 进行迭代更新。这个模块结合几何操作(如三角形乘法)和注意力机制,逐步建模蛋白质序列中残基之间的复杂关系。

源代码:

class PairStack(nn.Module):
    def __init__(
            self,
            c_z: int,
            c_hidden_tri_mul: int = 128,
            c_hidden_pair_attn: int = 32,
            no_heads_tri_attn: int = 4,
            transition_n: int = 4,
            pair_dropout: float = 0.25,
            fuse_projection_weights: bool = False,
            inf: float = 1e8,
    ):
        super(PairStack, self).__init__()

        if fuse_projection_weights:
            self.tri_mul_out = FusedTriangleMultiplicationOutgoing(
                c_z,
                c_hidden_tri_mul,
            )
            self.tri_mul_in = FusedTriangleMultiplicationIncoming(
                c_z,
                c_hidden_tri_mul,
            )
        else:
            self.tri_mul_out = TriangleMultiplicationOutgoing(
                c_z,
                c_hidden_tri_mul,
            )
            self.tri_mul_in = TriangleMultiplicationIncoming(
                c_z,
                c_hidden_tri_mul,
            )

        self.tri_att_start = TriangleAttentionStartingNode(
            c_z,
            c_hidden_pair_attn,
            no_heads_tri_attn,
            inf=inf,
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            c_z,
            c_hidden_pair_attn,
            no_heads_tri_attn,
            inf=inf,
        )

        self.transition = Transition(
            c_z,
            transition_n,
        )

        self.dropout_row_layer = DropoutRowwise(pair_dropout)
        self.dropout_col_layer = DropoutColumnwise(pair_dropout)

    def forward(
        

网站公告

今日签到

点亮在社区的每一天
去签到