AF3: A Walkthrough of the ProteinDataModule Class

Published: 2025-04-16

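The ProteinDataModule class in AlphaFold3's protein_datamodule module inherits from PyTorch Lightning's LightningDataModule. It encapsulates the preparation, loading, splitting, and transformation of ProteinFlow data in one place, which makes training runs easier to manage and reproduce.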
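To make the division of labor concrete, here is a minimal, dependency-free sketch of the hook order a data module goes through during fitting. `ToyDataModule` and `run_fit` are hypothetical names invented for illustration; the class only mirrors the LightningDataModule hook names and does not inherit from it, and a plain list of batches stands in for a PyTorch DataLoader. The real trainer does more (for example validation and distributed handling), so treat this as a simplified picture rather than Lightning's actual control flow.

```python
from typing import List, Optional


class ToyDataModule:
    """Stand-in that mirrors LightningDataModule hook names (no Lightning needed)."""

    def __init__(self, batch_size: int = 2) -> None:
        self.batch_size = batch_size
        self.calls: List[str] = []  # records the order in which hooks run
        self.train_set: Optional[List[int]] = None
        self.val_set: Optional[List[int]] = None

    def prepare_data(self) -> None:
        # One-off work such as downloading; no state is assigned here.
        self.calls.append("prepare_data")

    def setup(self, stage: str) -> None:
        # Per-process work: build the splits and keep them on the instance.
        self.calls.append(f"setup({stage})")
        samples = list(range(10))  # placeholder for real dataset entries
        self.train_set, self.val_set = samples[:8], samples[8:]

    def train_dataloader(self) -> List[List[int]]:
        # A list of batches in place of torch.utils.data.DataLoader.
        self.calls.append("train_dataloader")
        step = self.batch_size
        return [self.train_set[i:i + step] for i in range(0, len(self.train_set), step)]


def run_fit(dm: ToyDataModule) -> List[List[int]]:
    """Call the hooks in the (simplified) order used for fitting."""
    dm.prepare_data()
    dm.setup("fit")
    return dm.train_dataloader()


batches = run_fit(ToyDataModule(batch_size=4))
```

Keeping downloads in `prepare_data` and split construction in `setup` is what lets the same module be reused unchanged across single-device and multi-process runs.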

The class handles four core tasks in AlphaFold3 training and evaluation (data preparation, splitting, transformation, and loading):

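| Task | Description |
| --- | --- |
| Data preparation (prepare_data) | Downloads/prepares the ProteinFlow dataset (structures and annotations) |
| Dataset construction (setup) | Builds the training, validation, and test sets and applies the transforms |
| Dataloaders (*_dataloader) | Return PyTorch DataLoader objects for training, validation, and testing |
| Augmentation and feature extraction | Applies Cropper, Reorder, and AF3Featurizer to generate the model's input features |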
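The last row describes a chain of transforms applied to each sample. The sketch below shows that idea in plain Python: crop a sample to at most `crop_size` residues, then derive features from the cropped sample. All names here (`make_cropper`, `toy_featurize`, `compose`) are hypothetical, and contiguous-window cropping is an assumption made for illustration; the actual cropping and featurization logic in the repository may differ.

```python
import random
from typing import Callable, Dict, List

Sample = Dict[str, list]
Transform = Callable[[Sample], Sample]


def make_cropper(crop_size: int, rng: random.Random) -> Transform:
    """Return a transform that keeps one contiguous window of at most crop_size residues."""
    def crop(sample: Sample) -> Sample:
        n = len(sample["sequence"])
        if n <= crop_size:
            return sample  # short entries pass through unchanged
        start = rng.randint(0, n - crop_size)  # randint is inclusive on both ends
        # Slice every per-residue field with the same window so they stay aligned.
        return {key: values[start:start + crop_size] for key, values in sample.items()}
    return crop


def toy_featurize(sample: Sample) -> Sample:
    """Add an integer residue index; a placeholder for real feature extraction."""
    return {**sample, "residue_index": list(range(len(sample["sequence"])))}


def compose(transforms: List[Transform]) -> Transform:
    """Chain transforms left to right."""
    def apply(sample: Sample) -> Sample:
        for transform in transforms:
            sample = transform(sample)
        return sample
    return apply


pipeline = compose([make_cropper(crop_size=4, rng=random.Random(0)), toy_featurize])
example = {"sequence": list("MKTAYIAK"), "coords": [[float(i)] * 3 for i in range(8)]}
features = pipeline(example)
```

Passing the random generator in explicitly keeps the crop reproducible, which matters when comparing training runs.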

Source code:

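from typing import List, Optional

from lightning import LightningDataModule  # assumed import path; some setups use pytorch_lightning instead


class ProteinDataModule(LightningDataModule):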
    """`LightningDataModule` for the Protein Data Bank.

    A `LightningDataModule` implements 7 key methods:

    ```python
        def prepare_data(self):
        # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
        # Download data, pre-process, split, save to disk, etc...

        def setup(self, stage):
        # Things to do on every process in DDP.
        # Load data, set variables, etc...

        def train_dataloader(self):
        # return train dataloader

        def val_dataloader(self):
        # return validation dataloader

        def test_dataloader(self):
        # return test dataloader

        def predict_dataloader(self):
        # return predict dataloader

        def teardown(self, stage):
        # Called on every process in DDP.
        # Clean up after fit or test.
    ```

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    def __init__(
            self,
            data_dir: str = "./data/",
            resolution_thr: float = 3.5,
            min_seq_id: float = 0.3,
            crop_size: int = 384,
            max_length: int = 10_000,
            use_fraction: float = 1.0,
            entry_type: str = "chain",
            classes_to_exclude: Optional[List[str]] = None,
            mask_residues: bool = False,
            lower_limit: int = 15,
            upper_limit: int = 100,
            mask_frac: Optional[float] = None,
            mask_sequential: bool = False,
            mask_whole_chains: bool = False,
            force_binding_sites_frac: float = 0.15,
            batch_size: int = 64,
            num_workers: int = 0,
            pin_memory: bool = False,
            debug: bool = False
    ) -> None:
        """Initialize a `ProteinDataModule`.

        :param resolution_thr: Resolution threshold for PDB structures
        :param min_seq_id: Minimum sequence identity for MMseqs2 clustering
        :param crop_size: The number of residues to crop the proteins to.
        :param max_length: Entries with total length of chains larger than max_length will be disregarded.
        :param use_fraction: the fraction of the clusters to use (first N in alphabetic order)
        :param entry_type: {"biounit", "chain", "pair"} the type of entries to generate ("biounit" for biounit-level
                            complexes, "chain" for chain-level, "pair" for chain-chain pairs (all pairs that are seen
                            in the same biounit and have intersecting coordinate clouds))
        :param classes_to_exclude: a list of classes to exclude from the dataset (select from "single_chains",
                                   "heteromers", "homomers")
        :param mask_residues: if True, the masked residues will be added to the output
        :param lower_limit: the lower limit of the number of residues to mask
        :param upper_limit: the upper limit of the number of residues to mask
