Understanding the AF3 gdt Functions

Published: 2025-02-15

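The AlphaFold3 functions gdt, gdt_ts, and gdt_ha compute the Global Distance Test (GDT) score, which is used to assess the accuracy of protein structure prediction. GDT measures the similarity between a predicted structure (p1) and the true structure (p2): for each distance cutoff it takes the fraction of valid residues whose predicted position lies within that cutoff of the true position, then averages those fractions over all cutoffs. Because the code compares corresponding positions directly, it assumes the two structures are already superimposed.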
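To make the arithmetic concrete, here is a minimal plain-Python sketch of the same calculation for a single pair of structures. The helper name gdt_toy and the toy coordinates are hypothetical and exist only for illustration; like the functions discussed in this post, the sketch performs no superposition. The 1, 2, 4, 8 cutoffs match gdt_ts in the source. The 0.5, 1, 2, 4 cutoffs are the conventional GDT_HA thresholds and are an assumption here, since they are not shown in the listing.

```python
import math


def gdt_toy(p1, p2, mask, cutoffs):
    """Plain-Python GDT for one structure pair (hypothetical helper, illustration only)."""
    n_valid = sum(mask)
    # Distance between corresponding positions (residue i in p1 vs. residue i in p2).
    dists = [math.dist(a, b) for a, b in zip(p1, p2)]
    fractions = []
    for c in cutoffs:
        # Count valid residues within the cutoff; masked-out residues contribute 0.
        within = sum(m for d, m in zip(dists, mask) if d <= c)
        fractions.append(within / (n_valid + 1e-8))
    # Average the per-cutoff fractions.
    return sum(fractions) / len(fractions)


# Four residues: the reference sits at the origin, predictions are offset along x,
# so the per-residue distances are 0.5, 1.5, 3.0 and 10.0.
pred = [(0.5, 0.0, 0.0), (1.5, 0.0, 0.0), (3.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
ref = [(0.0, 0.0, 0.0)] * 4
mask = [1, 1, 1, 1]

score_ts = gdt_toy(pred, ref, mask, [1.0, 2.0, 4.0, 8.0])  # GDT_TS cutoffs from the source
score_ha = gdt_toy(pred, ref, mask, [0.5, 1.0, 2.0, 4.0])  # assumed conventional GDT_HA cutoffs

print(round(score_ts, 4))  # 0.5625 = (1/4 + 2/4 + 3/4 + 3/4) / 4
print(round(score_ha, 4))  # 0.4375 = (1/4 + 1/4 + 2/4 + 3/4) / 4
```

The stricter cutoffs give the same prediction a lower score, which is why GDT_HA is the more demanding of the two variants. Setting the last mask entry to 0 removes the 10.0 outlier from both the numerator and the denominator, raising the GDT_TS value to 0.75.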

Source code:

import torch


def gdt(p1, p2, mask, cutoffs):
    """
    Calculate the Global Distance Test (GDT) score for protein structures.

    Args:
        p1 (torch.Tensor): Coordinates of the first structure [..., N, 3].
        p2 (torch.Tensor): Coordinates of the second structure [..., N, 3].
        mask (torch.Tensor): Mask for valid residues [..., N].
        cutoffs (list): List of distance cutoffs for GDT calculation.

    Returns:
        torch.Tensor: GDT score [...].
    """
    # Ensure inputs are float
    p1 = p1.float()
    p2 = p2.float()
    mask = mask.float()

    # Calculate number of valid residues per batch
    n = torch.sum(mask, dim=-1)

    # Calculate per-residue distances between corresponding positions in p1 and p2
    distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1))

    scores = []
    for c in cutoffs:
        # Calculate score for each cutoff, accounting for the mask
        score = torch.sum((distances <= c).float() * mask, dim=-1) / (n + 1e-8)
        scores.append(score)

    # Stack scores and average across cutoffs
    scores = torch.stack(scores, dim=-1)
    return torch.mean(scores, dim=-1)


def gdt_ts(p1, p2, mask):
    """
    Calculate the Global Distance Test Total Score (GDT_TS).

    Args:
        p1 (torch.Tensor): Coordinates of the first structure [..., N, 3].
        p2 (torch.Tensor): Coordinates of the second structure [..., N, 3].
        mask (torch.Tensor): Mask for valid residues [..., N].

    Returns:
        torch.Tensor: GDT_TS score [...].
    """
    return gdt(p1, p2, mask, [1., 2., 4., 8.])


def gdt_ha(p1, p2, mask):
    """
    Calculate the Global Distance Test High Accuracy (GDT_HA) score.

    Args:
        p1 (torch.Tensor): Coo