DPFunc蛋白质功能预测模型复现报告

发布于:2025-04-07 ⋅ 阅读:(31) ⋅ 点赞:(0)

模型简介

模型的具体介绍见蛋白质功能预测论文阅读记录2025(DPFunc、ProtCLIP)_protein functions-CSDN博客

复现流程

仓库:CSUBioGroup/DPFunc

时间:2025.4.5

环境配置

python 3.9.21 & CUDA 11.6

Pytorch: 1.12.0

DGL: 1.1.0(需要安装cuda版本)

download wheels in this page https://data.dgl.ai/wheels/cu116/repo.html and use

pip install 'dgl-1.1.0-cp39-cp39-manylinux1_x86_64.whl'

一个报错的解决:

OSError: libcusparse.so.11: cannot open shared object file: No such file or directory

在需要使用python的终端中运行下述语句即可

export LD_LIBRARY_PATH="~/miniconda3/envs/xxxxx/lib:$LD_LIBRARY_PATH

数据集准备

首先选好自己的蛋白质集合

先下载pdb和interpro的数据,有数据的就可以跳过了。

下载pdb代码,直接用wget就可以:

import os
import subprocess
from time import sleep
st = set()
file_list = os.listdir("pdb")
for filename in file_list:
    if filename.startswith("AF"):
        st.add(filename.split("-")[1])

urls = []
# pids.txt保存的是需要下载pdb文件的蛋白质编号(uniprot)
with open(os.path.join("pids.txt"),"r") as f:
    lines = f.readlines()
    for line in lines:
        pid = line.strip()
        url = "https://alphafold.ebi.ac.uk/files/AF-"+pid+"-F1-model_v4.pdb\n"
        urls.append(url)

for i,url in enumerate(urls):
    if i < len(file_list):
        continue
    pid = url.split("/")[-1].split("-")[1]
    if pid not in st:
        print(pid)
        output_file_path = os.path.join("pdb","AF-"+pid+"-F1-model_v4.pdb")
        try:
            result = subprocess.run(
                ["wget", "-O", output_file_path, url.strip()],
                check=True,  # 如果返回值不为 0,将引发异常
                stdout=subprocess.PIPE,  # 捕获标准输出
                stderr=subprocess.PIPE   # 捕获标准错误
            )
            print(f"文件下载成功,返回值:{result.returncode}")
        except subprocess.CalledProcessError as e:
            print(f"下载失败,返回值:{e.returncode}")
            print(f"错误信息:{e.stderr.decode()}")
            os.system("rm %s"%(output_file_path))
        sleep(1)

下载interpro文件的代码:

'''
修改自InterPro官网上的代码
用于读取InterPro上的查找结果 export.tsv 并根据结果下载所有蛋白质的结构域信息
'''
# standard library modules
import sys, errno, re, json, ssl, os
from urllib import request
from urllib.error import HTTPError
from time import sleep

def parse_items(items):
    if type(items)==list:
        return ",".join(items)
    return ""
def parse_member_databases(dbs):
    if type(dbs)==dict:
        return ";".join([f"{db}:{','.join(dbs[db])}" for db in dbs.keys()])
    return ""
def parse_go_terms(gos):
    if type(gos)==list:
        return ",".join([go["identifier"] for go in gos])
    return ""
def parse_locations(locations):
    if type(locations)==list:
        return ",".join(
        [",".join([f"{fragment['start']}..{fragment['end']}" 
                    for fragment in location["fragments"]
                    ])
        for location in locations
        ])
    return ""
def parse_group_column(values, selector):
    return ",".join([parse_column(value, selector) for value in values])

def parse_column(value, selector):
    if value is None:
        return ""
    elif "member_databases" in selector:
        return parse_member_databases(value)
    elif "go_terms" in selector: 
        return parse_go_terms(value)
    elif "children" in selector: 
        return parse_items(value)
    elif "locations" in selector:
        return parse_locations(value)
    return str(value)

def download_to_file(url, file_path):
    #disable SSL verification to avoid config issues
    context = ssl._create_unverified_context()

    next = url
    last_page = False

    
    attempts = 0
    while next:
        try:
            req = request.Request(next, headers={"Accept": "application/json"})
            res = request.urlopen(req, context=context)
            # If the API times out due a long running query
            if res.status == 408:
                # wait just over a minute
                sleep(61)
                # then continue this loop with the same URL
                continue
            elif res.status == 204:
                #no data so leave loop
                break
            payload = json.loads(res.read().decode())
            next = payload["next"]
            attempts = 0
            if not next:
                last_page = True
        except HTTPError as e:
            if e.code == 408:
                sleep(61)
                continue
            else:
                # If there is a different HTTP error, it wil re-try 3 times before failing
                if attempts < 3:
                    attempts += 1
                    sleep(61)
                    continue
                else:
                    sys.stderr.write("LAST URL: " + next)
                    raise e
        with open(file_path,"w+") as f:
            for i, item in enumerate(payload["results"]):
                f.write(parse_column(item["metadata"]["accession"], 'metadata.accession') + "\t")
                f.write(parse_column(item["metadata"]["name"], 'metadata.name') + "\t")
                f.write(parse_column(item["metadata"]["source_database"], 'metadata.source_database') + "\t")
                f.write(parse_column(item["metadata"]["type"], 'metadata.type') + "\t")
                f.write(parse_column(item["metadata"]["integrated"], 'metadata.integrated') + "\t")
                f.write(parse_column(item["metadata"]["member_databases"], 'metadata.member_databases') + "\t")
                f.write(parse_column(item["metadata"]["go_terms"], 'metadata.go_terms') + "\t")
                f.write(parse_column(item["proteins"][0]["accession"], 'proteins[0].accession') + "\t")
                f.write(parse_column(item["proteins"][0]["protein_length"], 'proteins[0].protein_length') + "\t")
                f.write(parse_column(item["proteins"][0]["entry_protein_locations"], 'proteins[0].entry_protein_locations') + "\t")
                f.write("\n")
        
    # Don't overload the server, give it time before asking for more
    sleep(1)

# 先读取之前已经完成的进度
exist_file_list = set()
already_exist_file = os.listdir("interpro")
for file in already_exist_file:
    if file.endswith(".tsv"):
        exist_file_list.add(file.split(".")[0])
        
with open(os.path.join("pids.txt"),"r") as f:
    lines = f.readlines()
cnt = 0
for line in lines:
    cnt+=1
    pid = line.strip()
    # 如果之前完成了,就跳过该条信息
    # if cnt <= 47:
    #     continue
    if pid in exist_file_list:
        print(pid, " exists")
        continue
    print(line,"  ",cnt,"/",len(lines))
    url = f"https://www.ebi.ac.uk:443/interpro/api/entry/InterPro/protein/reviewed/{pid}/?page_size=200"
    download_to_file(url,os.path.join("interpro", pid+'.tsv'))

使用ESM-1b-650M跑embedding,并取第31层的embedding进行保留,注意这里不能直接做sum或者mean的readout,需要整个序列的embedding都保存。

ESM1b代码:(直接更换esm.pretrained后面的模型就可以跑其他的esm模型了)

'''
用于生成序列中每个氨基酸的特征向量
'''
import json
import os
import pickle
import tqdm
import pandas as pd
from pandas import DataFrame as df

import torch
import esm

# 设置CUDA设备编号
device = "cuda:2"
# ESM2预训练模型初始化,取消梯度
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
converter = alphabet.get_batch_converter()
model.to(device)
for p in model.parameters():
    p.requires_grad = False
model.eval()

species_name = "xxxx"
# 读取预处理后的文件,这个不重要,只需要根据protein的id来提供序列就可以了
with open(os.path.join(species_name+"_protein_info.json"),'r') as f:
    protein_info = json.load(f)

print("start ESM1b, get sequence embeddings!")
cnt = 0
for uid in protein_info.keys():
    cnt += 1
    print(cnt,'/',len(protein_info))
    val = protein_info[uid]
    labels, strs, tokens = converter([(uid,val["seq"])])
    tokens = tokens.to(device)
    seq_emb = model(tokens, repr_layers=[31])['representations'][31]
    seq_emb = seq_emb.squeeze(0)
    seq_emb = seq_emb.cpu().numpy()
    
    dir_path = os.path.join("dataset",species_name,uid)
    os.makedirs(dir_path, exist_ok=True)
    with open(os.path.join(dir_path,"seq_emb.pkl"), 'wb') as f:
        pickle.dump(seq_emb, f)
    del tokens,seq_emb
    torch.cuda.empty_cache()

根据仓库中的需要,导出{bp/cc/mf}_{train/valid/test}_pid_list.pkl和pid2esm.pkl。

这个太简单就不给代码了。

数据处理梳理

这一部分踩了很多坑,改了很多代码才调通。

首先generate_points是对所有的数据进行处理的,处理完之后大家都可以读取,所以这里的pid_list.pkl需要包含所有的蛋白质编号。

process_graph需要单独对每一个{bp/cc/mf}_{train/valid/test}_pid_list.pkl做,也就是做出来至少9个文件。大概就是下图所示:

process_interpro,这个可以不用跑,因为在后面的main和pred中它还可以能单独跑,但如果跑了,就需要在yaml配置里面的base填写interpro_whole的路径。

这里因为我们下载下来的interpro是tsv文件,如果直接处理成一行01序列onehot编码会占用很多空间,魔改一下data_utils.py中的get_inter_whole_data函数,让他通过读取标签序列来初始化稀疏矩阵。

def get_inter_whole_data(pid_list, save_file, domain_map, protein_info):
    rows = []
    cols = []
    data = []
    for i in trange(len(pid_list)):
        pid = pid_list[i]
        if pid in protein_info:
            domain_list = protein_info[pid]['domain']
            for domain in domain_list:
                if domain[0] not in domain_map:
                    continue
                domain_id = domain_map[domain[0]]
                rows += [i]
                cols += [domain_id]
                data += [1]
    
    # col_nodes = np.max(cols) + 1
    col_nodes = 22369
    interpro_matrix = csr_matrix((data, (rows, cols)), shape=(len(pid_list), col_nodes))
    with open(save_file, 'wb') as fw:
        pkl.dump(interpro_matrix, fw)
    
    print(interpro_matrix.shape)
    return interpro_matrix

process_structure,这个看起来是作者自用的处理代码,里面都是本地路径,不用管。

另外一个需要注意的事项,提供给DPFunc的go.txt必须是经过传播的go集合,否则训练效果会很差。

跑代码

首先运行:(注意填写正确的esm和pdb的路径)

python DataProcess/generate_points.py -i data/pid_list.pkl -o data/pid_points.pkl

如果只跑test,使用以下流程:

配置填写:(以bp为例)

name: bp
mlb: ./mlb/bp_go.mlb
results: ./results

base:
  pdb_points: ./data/pid_points.pkl
test:
  name: test
  pid_list_file: ./data/pid_list.pkl
  pid_go_file: ./data/bp_go.txt
  pid_pdb_file: ./data/PDB/bp_test_whole_pdb_part0.pkl
  interpro_file: ./data/bp_interpro.pkl

注意这里跑之前需要把process_graph代码中的输出路径改好

python DataProcess/process_graph.py -d bp

python DataProcess/process_graph.py -d cc

python DataProcess/process_graph.py -d mf

下载已有的DPFunc权重https://drive.google.com/file/d/1V0VTFTiB29ilbAIOZn0okBQWPlbOI3wN/view?usp=drive_link

解压到save_models之后运行预测代码

python DPFunc_pred.py -d bp -n 0 -p DPFunc_model

python DPFunc_pred.py -d cc -n 0 -p DPFunc_model

python DPFunc_pred.py -d mf -n 0 -p DPFunc_model

结果:

bp: Fmax 0.7429, AUPR 0.7608

cc: Fmax 0.8246, AUPR 0.8775

mf: Fmax 0.7238, AUPR 0.7607

跑train+test,使用以下流程:

先要分离train/valid/test的pid_list

填写配置:

name: bp
mlb: ./mlb/bp_go.mlb
results: ./results

base:
  interpro_whole: 
  pdb_points: ./data/pid_points.pkl

train:
  name: train
  pid_list_file: ./data/bp_train_pid_list.pkl
  pid_go_file: ./data/bp_train_go.txt
  pid_pdb_file: ./data/PDB/bp_train_whole_pdb_part{}.pkl
  train_file_count: 1
  interpro_file: ./data/bp_train_interpro.pkl

valid:
  name: valid
  pid_list_file: ./data/bp_valid_pid_list.pkl
  pid_go_file: ./data/bp_valid_go.txt
  pid_pdb_file: ./data/PDB/bp_valid_whole_pdb_part0.pkl
  interpro_file: ./data/bp_valid_interpro.pkl

test:
  name: test
  pid_list_file: ./data/bp_test_pid_list.pkl
  pid_go_file: ./data/bp_test_go.txt
  pid_pdb_file: ./data/PDB/bp_test_whole_pdb_part0.pkl
  interpro_file: ./data/bp_test_interpro.pkl

然后还是需要改process_graph的代码,让他把train/valid/test三个都分别读取对应的文件再输出到对应的路径

重新进行process_graph.py

python DPFunc_main.py -d bp -n 0 -e 15 -p temp_model

python DPFunc_main.py -d cc -n 0 -e 15 -p temp_model

python DPFunc_main.py -d mf -n 0 -e 15 -p temp_model

验证集结果:

bp: AUC 0.9989, Fmax 0.5867, AUPR 0.5913, cut-off: 0.55

cc: AUC 0.9990, Fmax 0.7179, AUPR 0.7711, cut-off: 0.58

mf: AUC 0.9992, Fmax 0.7500, AUPR 0.7826, cut-off: 0.45

python DPFunc_pred.py -d bp -n 0 -p temp_model

python DPFunc_pred.py -d cc -n 0 -p temp_model

python DPFunc_pred.py -d mf -n 0 -p temp_model

测试集结果:

bp: AUC 0.9990, Fmax 0.5867, AUPR 0.5997, cut-off: 0.56

cc: AUC 0.9993, Fmax 0.7238, AUPR 0.7690, cut-off: 0.72

mf: AUC 0.9992, Fmax 0.7500, AUPR 0.7839, cut-off: 0.45

吐槽

整个复现花了5~6天的时间,大部分时间都在数据处理上,模型的效果确实很好,训练也很快,但是中途有一些不太好的设计。

1、冗余的dgl图特征,仅仅只是把esm的特征保存在"x"里面,然后用的时候又是拿出来跑网络,这样一份esm特征我们就需要保存两遍,很占空间。一般没有自定义传播函数的时候,建议dgl图就只保存框架,特征单独拿出来放,大部分DGL自带的GCN都是分离节点特征输入的。

2、路径修改繁琐,几个数据处理文件里面,既有读取配置文件的路径,又有写死的本地路径,改一个路径输入输出需要非常仔细才能不出错,有时候改了这里就忘了改那里。

3、函数不统一,训练代码用的读取函数和输出处理函数的名称相同,但是内容略有差别,调用的是不同位置的函数,改完数据处理函数发现没用,调了半天才发现训练里面的数据处理又是独立的。

4、计算Fmax指标时,有一个对预测结果进行传递的过程:

这个非常慢,每一个GO标签都要更新它的传递闭包,并且它是一个一个蛋白质计算的,时间复杂度大概是O(蛋白质个数 * 标签总数 * 平均每个GO标签的传递闭包大小)。

优化的话,可以先求GO图的拓扑序,按照GO拓扑序更新,从具体的GO标签更新到抽象的顶层GO标签,然后预测出来的结果可以作为一个batch更新,而不是一个一个蛋白质更新。

不过我现在还没来得及优化这个,有空再说吧