模型简介
模型的具体介绍见蛋白质功能预测论文阅读记录2025(DPFunc、ProtCLIP)_protein functions-CSDN博客
复现流程
时间:2025.4.5
环境配置
python 3.9.21 & CUDA 11.6
Pytorch: 1.12.0
DGL: 1.1.0(需要安装cuda版本)
download wheels in this page https://data.dgl.ai/wheels/cu116/repo.html and use
pip install 'dgl-1.1.0-cp39-cp39-manylinux1_x86_64.whl'
一个报错的解决:
OSError: libcusparse.so.11: cannot open shared object file: No such file or directory
在需要使用python的终端中运行下述语句即可
export LD_LIBRARY_PATH="~/miniconda3/envs/xxxxx/lib:$LD_LIBRARY_PATH
数据集准备
首先选好自己的蛋白质集合
先下载pdb和interpro的数据,有数据的就可以跳过了。
下载pdb代码,直接用wget就可以:
import os
import subprocess
from time import sleep
st = set()
file_list = os.listdir("pdb")
for filename in file_list:
if filename.startswith("AF"):
st.add(filename.split("-")[1])
urls = []
# pids.txt保存的是需要下载pdb文件的蛋白质编号(uniprot)
with open(os.path.join("pids.txt"),"r") as f:
lines = f.readlines()
for line in lines:
pid = line.strip()
url = "https://alphafold.ebi.ac.uk/files/AF-"+pid+"-F1-model_v4.pdb\n"
urls.append(url)
for i,url in enumerate(urls):
if i < len(file_list):
continue
pid = url.split("/")[-1].split("-")[1]
if pid not in st:
print(pid)
output_file_path = os.path.join("pdb","AF-"+pid+"-F1-model_v4.pdb")
try:
result = subprocess.run(
["wget", "-O", output_file_path, url.strip()],
check=True, # 如果返回值不为 0,将引发异常
stdout=subprocess.PIPE, # 捕获标准输出
stderr=subprocess.PIPE # 捕获标准错误
)
print(f"文件下载成功,返回值:{result.returncode}")
except subprocess.CalledProcessError as e:
print(f"下载失败,返回值:{e.returncode}")
print(f"错误信息:{e.stderr.decode()}")
os.system("rm %s"%(output_file_path))
sleep(1)
下载interpro文件的代码:
'''
修改自InterPro官网上的代码
用于读取InterPro上的查找结果 export.tsv 并根据结果下载所有蛋白质的结构域信息
'''
# standard library modules
import sys, errno, re, json, ssl, os
from urllib import request
from urllib.error import HTTPError
from time import sleep
def parse_items(items):
if type(items)==list:
return ",".join(items)
return ""
def parse_member_databases(dbs):
if type(dbs)==dict:
return ";".join([f"{db}:{','.join(dbs[db])}" for db in dbs.keys()])
return ""
def parse_go_terms(gos):
if type(gos)==list:
return ",".join([go["identifier"] for go in gos])
return ""
def parse_locations(locations):
if type(locations)==list:
return ",".join(
[",".join([f"{fragment['start']}..{fragment['end']}"
for fragment in location["fragments"]
])
for location in locations
])
return ""
def parse_group_column(values, selector):
return ",".join([parse_column(value, selector) for value in values])
def parse_column(value, selector):
if value is None:
return ""
elif "member_databases" in selector:
return parse_member_databases(value)
elif "go_terms" in selector:
return parse_go_terms(value)
elif "children" in selector:
return parse_items(value)
elif "locations" in selector:
return parse_locations(value)
return str(value)
def download_to_file(url, file_path):
#disable SSL verification to avoid config issues
context = ssl._create_unverified_context()
next = url
last_page = False
attempts = 0
while next:
try:
req = request.Request(next, headers={"Accept": "application/json"})
res = request.urlopen(req, context=context)
# If the API times out due a long running query
if res.status == 408:
# wait just over a minute
sleep(61)
# then continue this loop with the same URL
continue
elif res.status == 204:
#no data so leave loop
break
payload = json.loads(res.read().decode())
next = payload["next"]
attempts = 0
if not next:
last_page = True
except HTTPError as e:
if e.code == 408:
sleep(61)
continue
else:
# If there is a different HTTP error, it wil re-try 3 times before failing
if attempts < 3:
attempts += 1
sleep(61)
continue
else:
sys.stderr.write("LAST URL: " + next)
raise e
with open(file_path,"w+") as f:
for i, item in enumerate(payload["results"]):
f.write(parse_column(item["metadata"]["accession"], 'metadata.accession') + "\t")
f.write(parse_column(item["metadata"]["name"], 'metadata.name') + "\t")
f.write(parse_column(item["metadata"]["source_database"], 'metadata.source_database') + "\t")
f.write(parse_column(item["metadata"]["type"], 'metadata.type') + "\t")
f.write(parse_column(item["metadata"]["integrated"], 'metadata.integrated') + "\t")
f.write(parse_column(item["metadata"]["member_databases"], 'metadata.member_databases') + "\t")
f.write(parse_column(item["metadata"]["go_terms"], 'metadata.go_terms') + "\t")
f.write(parse_column(item["proteins"][0]["accession"], 'proteins[0].accession') + "\t")
f.write(parse_column(item["proteins"][0]["protein_length"], 'proteins[0].protein_length') + "\t")
f.write(parse_column(item["proteins"][0]["entry_protein_locations"], 'proteins[0].entry_protein_locations') + "\t")
f.write("\n")
# Don't overload the server, give it time before asking for more
sleep(1)
# 先读取之前已经完成的进度
exist_file_list = set()
already_exist_file = os.listdir("interpro")
for file in already_exist_file:
if file.endswith(".tsv"):
exist_file_list.add(file.split(".")[0])
with open(os.path.join("pids.txt"),"r") as f:
lines = f.readlines()
cnt = 0
for line in lines:
cnt+=1
pid = line.strip()
# 如果之前完成了,就跳过该条信息
# if cnt <= 47:
# continue
if pid in exist_file_list:
print(pid, " exists")
continue
print(line," ",cnt,"/",len(lines))
url = f"https://www.ebi.ac.uk:443/interpro/api/entry/InterPro/protein/reviewed/{pid}/?page_size=200"
download_to_file(url,os.path.join("interpro", pid+'.tsv'))
使用ESM-1b-650M跑embedding,并取第31层的embedding进行保留,注意这里不能直接做sum或者mean的readout,需要整个序列的embedding都保存。
ESM1b代码:(直接更换esm.pretrained后面的模型就可以跑其他的esm模型了)
'''
用于生成序列中每个氨基酸的特征向量
'''
import json
import os
import pickle
import tqdm
import pandas as pd
from pandas import DataFrame as df
import torch
import esm
# 设置CUDA设备编号
device = "cuda:2"
# ESM2预训练模型初始化,取消梯度
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
converter = alphabet.get_batch_converter()
model.to(device)
for p in model.parameters():
p.requires_grad = False
model.eval()
species_name = "xxxx"
# 读取预处理后的文件,这个不重要,只需要根据protein的id来提供序列就可以了
with open(os.path.join(species_name+"_protein_info.json"),'r') as f:
protein_info = json.load(f)
print("start ESM1b, get sequence embeddings!")
cnt = 0
for uid in protein_info.keys():
cnt += 1
print(cnt,'/',len(protein_info))
val = protein_info[uid]
labels, strs, tokens = converter([(uid,val["seq"])])
tokens = tokens.to(device)
seq_emb = model(tokens, repr_layers=[31])['representations'][31]
seq_emb = seq_emb.squeeze(0)
seq_emb = seq_emb.cpu().numpy()
dir_path = os.path.join("dataset",species_name,uid)
os.makedirs(dir_path, exist_ok=True)
with open(os.path.join(dir_path,"seq_emb.pkl"), 'wb') as f:
pickle.dump(seq_emb, f)
del tokens,seq_emb
torch.cuda.empty_cache()
根据仓库中的需要,导出{bp/cc/mf}_{train/valid/test}_pid_list.pkl和pid2esm.pkl。
这个太简单就不给代码了。
数据处理梳理
这一部分踩了很多坑,改了很多代码才调通。
首先generate_points是对所有的数据进行处理的,处理完之后大家都可以读取,所以这里的pid_list.pkl需要包含所有的蛋白质编号。
process_graph需要单独对每一个{bp/cc/mf}_{train/valid/test}_pid_list.pkl做,也就是做出来至少9个文件。大概就是下图所示:
process_interpro,这个可以不用跑,因为在后面的main和pred中它还可以能单独跑,但如果跑了,就需要在yaml配置里面的base填写interpro_whole的路径。
这里因为我们下载下来的interpro是tsv文件,如果直接处理成一行01序列onehot编码会占用很多空间,魔改一下data_utils.py中的get_inter_whole_data函数,让他通过读取标签序列来初始化稀疏矩阵。
def get_inter_whole_data(pid_list, save_file, domain_map, protein_info):
rows = []
cols = []
data = []
for i in trange(len(pid_list)):
pid = pid_list[i]
if pid in protein_info:
domain_list = protein_info[pid]['domain']
for domain in domain_list:
if domain[0] not in domain_map:
continue
domain_id = domain_map[domain[0]]
rows += [i]
cols += [domain_id]
data += [1]
# col_nodes = np.max(cols) + 1
col_nodes = 22369
interpro_matrix = csr_matrix((data, (rows, cols)), shape=(len(pid_list), col_nodes))
with open(save_file, 'wb') as fw:
pkl.dump(interpro_matrix, fw)
print(interpro_matrix.shape)
return interpro_matrix
process_structure,这个看起来是作者自用的处理代码,里面都是本地路径,不用管。
另外一个需要注意的事项,提供给DPFunc的go.txt必须是经过传播的go集合,否则训练效果会很差。
跑代码
首先运行:(注意填写正确的esm和pdb的路径)
python DataProcess/generate_points.py -i data/pid_list.pkl -o data/pid_points.pkl
如果只跑test,使用以下流程:
配置填写:(以bp为例)
name: bp
mlb: ./mlb/bp_go.mlb
results: ./results
base:
pdb_points: ./data/pid_points.pkl
test:
name: test
pid_list_file: ./data/pid_list.pkl
pid_go_file: ./data/bp_go.txt
pid_pdb_file: ./data/PDB/bp_test_whole_pdb_part0.pkl
interpro_file: ./data/bp_interpro.pkl
注意这里跑之前需要把process_graph代码中的输出路径改好
python DataProcess/process_graph.py -d bp
python DataProcess/process_graph.py -d cc
python DataProcess/process_graph.py -d mf
下载已有的DPFunc权重https://drive.google.com/file/d/1V0VTFTiB29ilbAIOZn0okBQWPlbOI3wN/view?usp=drive_link
解压到save_models之后运行预测代码
python DPFunc_pred.py -d bp -n 0 -p DPFunc_model
python DPFunc_pred.py -d cc -n 0 -p DPFunc_model
python DPFunc_pred.py -d mf -n 0 -p DPFunc_model
结果:
bp: Fmax 0.7429, AUPR 0.7608
cc: Fmax 0.8246, AUPR 0.8775
mf: Fmax 0.7238, AUPR 0.7607
跑train+test,使用以下流程:
先要分离train/valid/test的pid_list
填写配置:
name: bp
mlb: ./mlb/bp_go.mlb
results: ./results
base:
interpro_whole:
pdb_points: ./data/pid_points.pkl
train:
name: train
pid_list_file: ./data/bp_train_pid_list.pkl
pid_go_file: ./data/bp_train_go.txt
pid_pdb_file: ./data/PDB/bp_train_whole_pdb_part{}.pkl
train_file_count: 1
interpro_file: ./data/bp_train_interpro.pkl
valid:
name: valid
pid_list_file: ./data/bp_valid_pid_list.pkl
pid_go_file: ./data/bp_valid_go.txt
pid_pdb_file: ./data/PDB/bp_valid_whole_pdb_part0.pkl
interpro_file: ./data/bp_valid_interpro.pkl
test:
name: test
pid_list_file: ./data/bp_test_pid_list.pkl
pid_go_file: ./data/bp_test_go.txt
pid_pdb_file: ./data/PDB/bp_test_whole_pdb_part0.pkl
interpro_file: ./data/bp_test_interpro.pkl
然后还是需要改process_graph的代码,让他把train/valid/test三个都分别读取对应的文件再输出到对应的路径
重新进行process_graph.py
python DPFunc_main.py -d bp -n 0 -e 15 -p temp_model
python DPFunc_main.py -d cc -n 0 -e 15 -p temp_model
python DPFunc_main.py -d mf -n 0 -e 15 -p temp_model
验证集结果:
bp: AUC 0.9989, Fmax 0.5867, AUPR 0.5913, cut-off: 0.55
cc: AUC 0.9990, Fmax 0.7179, AUPR 0.7711, cut-off: 0.58
mf: AUC 0.9992, Fmax 0.7500, AUPR 0.7826, cut-off: 0.45
python DPFunc_pred.py -d bp -n 0 -p temp_model
python DPFunc_pred.py -d cc -n 0 -p temp_model
python DPFunc_pred.py -d mf -n 0 -p temp_model
测试集结果:
bp: AUC 0.9990, Fmax 0.5867, AUPR 0.5997, cut-off: 0.56
cc: AUC 0.9993, Fmax 0.7238, AUPR 0.7690, cut-off: 0.72
mf: AUC 0.9992, Fmax 0.7500, AUPR 0.7839, cut-off: 0.45
吐槽
整个复现花了5~6天的时间,大部分时间都在数据处理上,模型的效果确实很好,训练也很快,但是中途有一些不太好的设计。
1、冗余的dgl图特征,仅仅只是把esm的特征保存在"x"里面,然后用的时候又是拿出来跑网络,这样一份esm特征我们就需要保存两遍,很占空间。一般没有自定义传播函数的时候,建议dgl图就只保存框架,特征单独拿出来放,大部分DGL自带的GCN都是分离节点特征输入的。
2、路径修改繁琐,几个数据处理文件里面,既有读取配置文件的路径,又有写死的本地路径,改一个路径输入输出需要非常仔细才能不出错,有时候改了这里就忘了改那里。
3、函数不统一,训练代码用的读取函数和输出处理函数的名称相同,但是内容略有差别,调用的是不同位置的函数,改完数据处理函数发现没用,调了半天才发现训练里面的数据处理又是独立的。
4、计算Fmax指标时,有一个对预测结果进行传递的过程:
这个非常慢,每一个GO标签都要更新它的传递闭包,并且它是一个一个蛋白质计算的,时间复杂度大概是O(蛋白质个数 * 标签总数 * 平均每个GO标签的传递闭包大小)。
优化的话,可以先求GO图的拓扑序,按照GO拓扑序更新,从具体的GO标签更新到抽象的顶层GO标签,然后预测出来的结果可以作为一个batch更新,而不是一个一个蛋白质更新。
不过我现在还没来得及优化这个,有空再说吧