An Introduction to the Triangle Multiplicative Update in AF3

Published: 2025-02-10

In protein structure prediction, the AlphaFold models introduce two key mechanisms for processing the pair representation tensor: the Triangle Multiplicative Update and Triangle Self-Attention.

The triangle multiplicative update propagates information along triangular relationships: it models the indirect interaction between two protein residues that is mediated by a common neighbor (a third residue). For each pair (i, j), the update aggregates over every third residue k, combining edge (i, k) with edge (j, k) ("outgoing") or edge (k, i) with edge (k, j) ("incoming").
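Concretely, in the "outgoing" variant (Algorithm 11 of the AlphaFold 2 supplement), the pair entry $z_{ij}$ is first layer-normalized, projected into gated edge features $a$ and $b$, and then updated by contracting over every third residue $k$:

$$a_{ij} = \sigma\left(\mathrm{Linear}(z_{ij})\right) \odot \mathrm{Linear}(z_{ij}), \qquad b_{ij} = \sigma\left(\mathrm{Linear}(z_{ij})\right) \odot \mathrm{Linear}(z_{ij})$$

$$\tilde{z}_{ij} = \sigma\left(\mathrm{Linear}(z_{ij})\right) \odot \mathrm{Linear}\left(\mathrm{LayerNorm}\left(\sum_{k} a_{ik} \odot b_{jk}\right)\right)$$

where $\sigma$ is the sigmoid and $\odot$ is elementwise multiplication. The "incoming" variant (Algorithm 12) replaces $\sum_k a_{ik} \odot b_{jk}$ with $\sum_k a_{ki} \odot b_{kj}$.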

[Figure: schematic of the triangle multiplicative update]

[Figure: the triangle multiplicative update algorithm]

In AlphaFold3, the triangle multiplicative update is implemented by the TriangleMultiplicationOutgoing and TriangleMultiplicationIncoming classes.

Class hierarchy (each class inherits from the one to its right):

TriangleMultiplicationOutgoing / TriangleMultiplicationIncoming <- TriangleMultiplicativeUpdate <- BaseTriangleMultiplicativeUpdate
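The two concrete classes differ only in the direction of the triangle contraction, so they can be created by pinning the `_outgoing` flag of the shared parent. A minimal sketch of this pattern (it mirrors OpenFold's use of `functools.partialmethod`, which also explains the otherwise-unused import in the listing below; the exact subclass bodies in the AF3 codebase may differ):

```python
from functools import partialmethod

class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
    """Algorithm 11: triangle update using outgoing edges."""
    __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=True)


class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
    """Algorithm 12: triangle update using incoming edges."""
    __init__ = partialmethod(TriangleMultiplicativeUpdate.__init__, _outgoing=False)
```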

Source code:

from functools import partialmethod
from typing import Optional
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from torch.nn import LayerNorm
from src.models.components.primitives import Linear
from src.utils.chunk_utils import chunk_layer
from src.utils.tensor_utils import add, permute_final_dims


class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
    """
    Implements Algorithms 11 and 12.
    """

    @abstractmethod
    def __init__(self, c_z, c_hidden, _outgoing):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            _outgoing:
                Whether to compute the "outgoing" (Algorithm 11) or
                "incoming" (Algorithm 12) edge update
        """
        super(BaseTriangleMultiplicativeUpdate, self).__init__()
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

        self.layer_norm_in = LayerNorm(self.c_z)
        self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(self,
                             a: torch.Tensor,
                             b: torch.Tensor,
                             _inplace_chunk_size: Optional[int] = None
                             ) -> torch.Tensor:
        if self._outgoing:
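            # "Outgoing" edges (Algorithm 11): after these permutations, the
            # batched matmul below computes p_ij = sum_k a_ik * b_jk per channel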
            a = permute_final_dims(a, (2, 0, 1))
            b = permute_final_dims(b, (2, 1, 0))
        else:
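            # "Incoming" edges (Algorithm 12): p_ij = sum_k a_ki * b_kj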
            a = permute_final_dims(a, (2, 1, 0))
            b = permute_final_dims(b, (2, 0, 1))

        if _inplace_chunk_size is not None:
            # To be replaced by torch vmap
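            # Chunk over the hidden-channel dimension (now at dim -3) and
            # write each partial matmul back into `a` to bound peak memory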
            for i in range(0, a.shape[-3], _inplace_chunk_size):
                a_chunk = a[..., i: i + _inplace_chunk_size, :, :]
                b_chunk = b[..., i: i + _inplace_chunk_size, :, :]
                a[..., i: i + _inplace_chunk_size, :, :] = (
                    torch.matmul(
                        a_chunk,
                        b_chunk,
                    )
                )

            p = a
        else:
            p = torch.matmul(a, b)

        return permute_final_dims(p, (1, 2, 0))

    @abstractmethod
    def forward(self,
                z: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                inplace_safe: bool = False,
                _add_with_inplace: bool = False
                ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input pair tensor
            mask:
                [*, N_res, N_res] pair mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        pass


class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
    """
    Implements Algorithms 11 and 12.
    """

    def __init__(self, c_z, c_hidden, _outgoing=True):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            _outgoing:
                Whether to use outgoing (True) or incoming (False) edges
        """
        super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
                                                           c_hidden=c_hidden,
                                                           _outgoing=_outgoing)

        # NOTE: the original listing is truncated here; the remainder below is
        # completed following the OpenFold reference implementation of
        # Algorithms 11 and 12.
        self.linear_a_p = Linear(self.c_z, self.c_hidden)
        self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
        self.linear_b_p = Linear(self.c_z, self.c_hidden)
        self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")

    def forward(self,
                z: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                inplace_safe: bool = False,
                _add_with_inplace: bool = False
                ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input pair tensor
            mask:
                [*, N_res, N_res] pair mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        # Simplified forward pass; the memory-saving in-place/chunked path
        # selected by inplace_safe is omitted here for clarity.
        if mask is None:
            mask = z.new_ones(z.shape[:-1])
        mask = mask.unsqueeze(-1)

        z = self.layer_norm_in(z)
        # Gated linear projections of the pair representation
        a = mask * self.sigmoid(self.linear_a_g(z)) * self.linear_a_p(z)
        b = mask * self.sigmoid(self.linear_b_g(z)) * self.linear_b_p(z)
        # Triangle contraction over the third residue k
        x = self._combine_projections(a, b)
        x = self.layer_norm_out(x)
        x = self.linear_z(x)
        # Output gate
        g = self.sigmoid(self.linear_g(z))
        return x * g
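As a quick shape check, a hypothetical usage sketch (the sizes are illustrative, and `TriangleMultiplicationOutgoing` is the subclass sketched earlier):

```python
import torch

# Illustrative sizes: batch 1, 64 residues, pair channels c_z = 128
tri_out = TriangleMultiplicationOutgoing(c_z=128, c_hidden=128)

z = torch.randn(1, 64, 64, 128)   # [*, N_res, N_res, C_z] pair representation
mask = torch.ones(1, 64, 64)      # [*, N_res, N_res] pair mask

updated = tri_out(z, mask=mask)
print(updated.shape)              # torch.Size([1, 64, 64, 128])
```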
