For protein structure prediction, the AlphaFold model introduces two key mechanisms for updating the pair representation tensor: the Triangle Multiplicative Update and Triangle Self-Attention.
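The triangle multiplicative update propagates information through triangles of residues. The representation of a residue pair (i, j) is updated from the pairs that connect i and j to a shared neighbor, a third residue k, which models indirect interactions between residues that are mediated by common neighbors.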
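To make this data flow concrete, here is a minimal PyTorch sketch of both variants. It uses torch.einsum instead of the permute-and-matmul approach of the listing's _combine_projections. The names TriangleMultiplicationSketch and combine_via_matmul are hypothetical; using plain nn.Linear (rather than the repository's Linear with custom initializers) and multiplying the mask into a and b are assumptions. combine_via_matmul re-expresses the permute + torch.matmul idea with movedim/transpose so the two formulations can be checked against each other.

```python
from typing import Optional

import torch
import torch.nn as nn


class TriangleMultiplicationSketch(nn.Module):
    """Hypothetical sketch of a triangle multiplicative update (not the listing's class)."""

    def __init__(self, c_z: int, c_hidden: int, outgoing: bool = True):
        super().__init__()
        self.outgoing = outgoing
        self.layer_norm_in = nn.LayerNorm(c_z)
        # Gated projections producing a and b (assumption: plain nn.Linear layers)
        self.linear_a_p = nn.Linear(c_z, c_hidden)
        self.linear_a_g = nn.Linear(c_z, c_hidden)
        self.linear_b_p = nn.Linear(c_z, c_hidden)
        self.linear_b_g = nn.Linear(c_z, c_hidden)
        self.linear_g = nn.Linear(c_z, c_z)  # output gate
        self.layer_norm_out = nn.LayerNorm(c_hidden)
        self.linear_z = nn.Linear(c_hidden, c_z)

    def forward(self, z: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # z: [*, N, N, C_z]; mask: [*, N, N]
        if mask is None:
            mask = z.new_ones(z.shape[:-1])
        mask = mask.unsqueeze(-1)  # [*, N, N, 1], broadcast over channels
        z = self.layer_norm_in(z)
        a = mask * torch.sigmoid(self.linear_a_g(z)) * self.linear_a_p(z)
        b = mask * torch.sigmoid(self.linear_b_g(z)) * self.linear_b_p(z)
        if self.outgoing:
            # Outgoing edges i->k and j->k: x_ij = sum_k a_ik * b_jk
            x = torch.einsum("...ikc,...jkc->...ijc", a, b)
        else:
            # Incoming edges k->i and k->j: x_ij = sum_k a_ki * b_kj
            x = torch.einsum("...kic,...kjc->...ijc", a, b)
        g = torch.sigmoid(self.linear_g(z))
        return g * self.linear_z(self.layer_norm_out(x))


def combine_via_matmul(a: torch.Tensor, b: torch.Tensor, outgoing: bool) -> torch.Tensor:
    """Hypothetical helper: same sums as the einsum above, via a batched matmul."""
    if outgoing:
        a = a.movedim(-1, -3)                    # [*, C, i, k] holding a_ik
        b = b.movedim(-1, -3).transpose(-1, -2)  # [*, C, k, j] holding b_jk
    else:
        a = a.movedim(-1, -3).transpose(-1, -2)  # [*, C, i, k] holding a_ki
        b = b.movedim(-1, -3)                    # [*, C, k, j] holding b_kj
    # The channel dim acts as a batch dim; matmul contracts over k
    return torch.matmul(a, b).movedim(-3, -1)    # [*, i, j, C]
```

Both formulations compute the same sums. The batched matmul treats each channel as an independent N_res x N_res matrix product, which is presumably what makes the chunked, in-place branch of _combine_projections easy to write.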
Figure: Triangle multiplicative update (diagram)
Figure: Triangle multiplicative update algorithm
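Because the algorithm figure is not reproduced here, the update can be written out explicitly. The equations below follow Algorithms 11 (outgoing) and 12 (incoming) of the AlphaFold2 supplementary information, the numbering cited in the listing's docstrings. The mask factor m_ij is included as an assumption about how the listing's mask argument is applied. Here a_ij and b_ij have c_hidden channels, g_ij has c_z channels, sigma is the logistic sigmoid, the circled dot is elementwise multiplication, and every Linear has its own weights.

```latex
\begin{aligned}
\mathbf{z}_{ij} &\leftarrow \mathrm{LayerNorm}(\mathbf{z}_{ij}) \\
\mathbf{a}_{ij} &= m_{ij}\,\sigma\big(\mathrm{Linear}(\mathbf{z}_{ij})\big) \odot \mathrm{Linear}(\mathbf{z}_{ij}),
\qquad
\mathbf{b}_{ij} = m_{ij}\,\sigma\big(\mathrm{Linear}(\mathbf{z}_{ij})\big) \odot \mathrm{Linear}(\mathbf{z}_{ij}) \\
\mathbf{g}_{ij} &= \sigma\big(\mathrm{Linear}(\mathbf{z}_{ij})\big) \\
\tilde{\mathbf{z}}_{ij} &= \mathbf{g}_{ij} \odot \mathrm{Linear}\Big(\mathrm{LayerNorm}\Big(\sum_{k} \mathbf{a}_{ik} \odot \mathbf{b}_{jk}\Big)\Big)
  && \text{(outgoing)} \\
\tilde{\mathbf{z}}_{ij} &= \mathbf{g}_{ij} \odot \mathrm{Linear}\Big(\mathrm{LayerNorm}\Big(\sum_{k} \mathbf{a}_{ki} \odot \mathbf{b}_{kj}\Big)\Big)
  && \text{(incoming)}
\end{aligned}
```

The only difference between the two variants is which index of the edge the shared residue k occupies: in the outgoing form both edges leave i and j, in the incoming form both edges arrive at i and j.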
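In AlphaFold3, the triangle multiplicative update is implemented by the TriangleMultiplicationOutgoing and TriangleMultiplicationIncoming classes, with the following class hierarchy (subclass <- parent class):
TriangleMultiplicationOutgoing / TriangleMultiplicationIncoming <- TriangleMultiplicativeUpdate <- BaseTriangleMultiplicativeUpdate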
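The listing imports functools.partialmethod, but the excerpt ends before the two concrete classes are defined. One plausible way to derive them, and the pattern OpenFold uses for its classes of the same names, is to pin the _outgoing flag with partialmethod. In the sketch below, TriangleMultiplicativeUpdateStub is a hypothetical stand-in for the real TriangleMultiplicativeUpdate so the example runs on its own; whether this repository defines its subclasses exactly this way is an assumption.

```python
from functools import partialmethod

import torch.nn as nn


class TriangleMultiplicativeUpdateStub(nn.Module):
    """Hypothetical stand-in for the listing's TriangleMultiplicativeUpdate."""

    def __init__(self, c_z: int, c_hidden: int, _outgoing: bool = True):
        super().__init__()
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing


# Sketch (assumed, not taken from the repository): each subclass reuses the
# parent __init__ and only fixes _outgoing; c_z and c_hidden stay free.
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdateStub):
    """Algorithm 11: combine edges i->k and j->k."""
    __init__ = partialmethod(TriangleMultiplicativeUpdateStub.__init__, _outgoing=True)


class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdateStub):
    """Algorithm 12: combine edges k->i and k->j."""
    __init__ = partialmethod(TriangleMultiplicativeUpdateStub.__init__, _outgoing=False)


outgoing = TriangleMultiplicationOutgoing(c_z=128, c_hidden=128)
incoming = TriangleMultiplicationIncoming(c_z=128, c_hidden=128)
```

With this pattern both subclasses share a single implementation and differ only in the _outgoing flag that _combine_projections checks when choosing its permutations.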
Source code:
from functools import partialmethod
from typing import Optional
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
from torch.nn import LayerNorm

from src.models.components.primitives import Linear
from src.utils.chunk_utils import chunk_layer
from src.utils.tensor_utils import add, permute_final_dims


class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
    """
    Implements Algorithms 11 and 12.
    """
    @abstractmethod
    def __init__(self, c_z, c_hidden, _outgoing):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            _outgoing:
                If True, combine "outgoing" edges (Algorithm 11);
                otherwise combine "incoming" edges (Algorithm 12)
        """
        super(BaseTriangleMultiplicativeUpdate, self).__init__()
        self.c_z = c_z
        self.c_hidden = c_hidden
        self._outgoing = _outgoing

        self.linear_g = Linear(self.c_z, self.c_z, init="gating")
        self.linear_z = Linear(self.c_hidden, self.c_z, init="final")

        self.layer_norm_in = LayerNorm(self.c_z)
        self.layer_norm_out = LayerNorm(self.c_hidden)

        self.sigmoid = nn.Sigmoid()

    def _combine_projections(self,
        a: torch.Tensor,
        b: torch.Tensor,
        _inplace_chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        # a, b: [*, N_res, N_res, C_hidden]. Move the channel dim first so that
        # a batched matmul contracts over the shared residue k per channel.
        if self._outgoing:
            # p_ij = sum_k a_ik * b_jk
            a = permute_final_dims(a, (2, 0, 1))
            b = permute_final_dims(b, (2, 1, 0))
        else:
            # p_ij = sum_k a_ki * b_kj
            a = permute_final_dims(a, (2, 1, 0))
            b = permute_final_dims(b, (2, 0, 1))

        if _inplace_chunk_size is not None:
            # To be replaced by torch vmap
            # Multiply chunks of channels and write each result back into a,
            # avoiding a separate full-size output tensor
            for i in range(0, a.shape[-3], _inplace_chunk_size):
                a_chunk = a[..., i: i + _inplace_chunk_size, :, :]
                b_chunk = b[..., i: i + _inplace_chunk_size, :, :]
                a[..., i: i + _inplace_chunk_size, :, :] = (
                    torch.matmul(
                        a_chunk,
                        b_chunk,
                    )
                )

            p = a
        else:
            p = torch.matmul(a, b)

        # [*, C_hidden, N_res, N_res] -> [*, N_res, N_res, C_hidden]
        return permute_final_dims(p, (1, 2, 0))

    @abstractmethod
    def forward(self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        inplace_safe: bool = False,
        _add_with_inplace: bool = False
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] input tensor
            mask:
                [*, N_res, N_res] input mask
        Returns:
            [*, N_res, N_res, C_z] output tensor
        """
        pass


class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
    """
    Implements Algorithms 11 and 12.
    """
    def __init__(self, c_z, c_hidden, _outgoing=True):
        """
        Args:
            c_z:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
        """
        super(TriangleMultiplicativeUpdate, self).__init__(c_z=c_z,
                                                           c_hidden=c_hidden,
                                                           _outgoing=_outgoing)

        self.linear_a_p = Linear(se