SaProt: Protein Language Modeling with Structure-aware Vocabulary

The repository is an official implementation of SaProt: Protein Language Modeling with Structure-aware Vocabulary.

If you have any questions about the paper or the code, feel free to raise an issue! SaProt should outperform ESM-2 in most tasks under fair evaluation settings.

The laboratory is hiring research assistants, interns, doctoral students, and postdoctoral researchers. Please contact the corresponding author for details.

News

  • 2024/04/18: We found a slight difference in the EC and GO evaluation and updated the re-evaluated results (see issue #23 for details).
  • 2024/03/08: We uploaded a simple function for zero-shot prediction of mutational effects (see the example below).
  • 2024/01/17: Our paper has been accepted as an ICLR 2024 spotlight 🎉🎉🎉!
  • 2023/10/30: We released a pre-trained SaProt 35M model and a 35M residue-sequence-only version of SaProt (for comparison). The residue-sequence-only SaProt (without 3Di tokens) performs very similarly to the official ESM-2 35M model (see the results below).
  • 2023/10/30: We released the results obtained with ESMFold structures (see the table below).

Overview

We propose a structure-aware vocabulary for protein language modeling. The vocabulary is constructed by encoding the protein structure into discrete 3Di tokens using Foldseek. We combine the residue tokens and the structure tokens to form a structure-aware sequence. Through large-scale pre-training, our model, SaProt, learns the relationship between structure and sequence. For more details, please refer to our paper: https://www.biorxiv.org/content/10.1101/2023.10.01.560349v2.
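
As a rough illustration of how residue tokens and structure tokens combine, the sketch below interleaves an amino acid string with a 3Di string of the same length. The function name `build_sa_sequence` is hypothetical and not part of the repository; the uppercase-residue and lowercase-structure convention is taken from the example sequence used later in this README.

```python
def build_sa_sequence(residues: str, struc_tokens: str) -> str:
    """Interleave residues (uppercase) with 3Di structure tokens (lowercase).

    Hypothetical helper for illustration only; the repository provides
    its own conversion function based on the foldseek binary.
    """
    if len(residues) != len(struc_tokens):
        raise ValueError("residue and structure strings must have equal length")
    # One structure-aware token = one residue letter + one structure letter.
    return "".join(r.upper() + s.lower() for r, s in zip(residues, struc_tokens))


# The residue and 3Di strings below are read off the README example sequence.
sa_seq = build_sa_sequence("MEVQLVQYK", "dvpprydav")
print(sa_seq)
```

Each pair of characters in the result is one token of the structure-aware vocabulary.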

Environment installation

Create a virtual environment

conda create -n SaProt python=3.10
conda activate SaProt

Install packages

bash environment.sh  

Prepare the SaProt model

We provide two ways to use SaProt: through the Hugging Face classes, or in the same way as models in the esm GitHub repository. Users can choose either one.

Model checkpoints

| Name | Size | Dataset |
|---|---|---|
| SaProt_35M_AF2 | 35M parameters | 40M AF2 structures |
| SaProt_650M_PDB | 650M parameters | 40M AF2 structures (phase1) + 60K PDB structures (phase2) |
| SaProt_650M_AF2 | 650M parameters | 40M AF2 structures |

New experimental results

Some experimental results are listed below. For more details, please refer to our paper.

35M Model

| Model | ClinVar (AUC) | ProteinGym (Spearman's ρ) | Thermostability (Spearman's ρ) | HumanPPI (Acc%) | Metal Ion Binding (Acc%) | EC (Fmax) | GO-MF (Fmax) | GO-BP (Fmax) | GO-CC (Fmax) | DeepLoc-Subcellular (Acc%) | DeepLoc-Binary (Acc%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| ESM-2 (35M) | 0.722 | 0.339 | 0.669 | 80.79 | 73.08 | 0.825 | 0.616 | 0.416 | 0.404 | 76.58 | 91.60 |
| SaProt-Seq (35M) | 0.738 | 0.337 | 0.672 | 80.56 | 73.23 | 0.821 | 0.608 | 0.413 | 0.403 | 76.67 | 91.16 |
| SaProt (35M) | 0.794 | 0.392 | 0.692 | 81.11 | 74.29 | 0.847 | 0.642 | 0.431 | 0.418 | 78.09 | 91.97 |

650M Model

| Model | ClinVar (AUC) | ProteinGym (Spearman's ρ) | Thermostability (Spearman's ρ) | HumanPPI (Acc%) | Metal Ion Binding (Acc%) | EC (Fmax) | GO-MF (Fmax) | GO-BP (Fmax) | GO-CC (Fmax) | DeepLoc-Subcellular (Acc%) | DeepLoc-Binary (Acc%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| ESM-2 (650M) | 0.862 | 0.475 | 0.680 | 76.67 | 71.56 | 0.868 | 0.670 | 0.473 | 0.470 | 82.09 | 91.96 |
| SaProt (650M) | 0.909 | 0.478 | 0.724 | 86.41 | 75.75 | 0.882 | 0.682 | 0.486 | 0.479 | 85.57 | 93.55 |

AlphaFold2 vs. ESMFold

We compare SaProt when using structures predicted by AF2 and by ESMFold. The results are shown below:

| Model | ClinVar (AUC) | ProteinGym (Spearman's ρ) | Thermostability (Spearman's ρ) | HumanPPI (Acc%) | Metal Ion Binding (Acc%) | EC (Fmax) | GO-MF (Fmax) | GO-BP (Fmax) | GO-CC (Fmax) | DeepLoc-Subcellular (Acc%) | DeepLoc-Binary (Acc%) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| SaProt (ESMFold) | 0.896 | 0.455 | 0.717 | 85.78 | 74.10 | 0.871 | 0.678 | 0.480 | 0.474 | 82.82 | 93.19 |
| SaProt (AF2) | 0.909 | 0.478 | 0.724 | 86.41 | 75.75 | 0.882 | 0.682 | 0.486 | 0.479 | 85.57 | 93.55 |

Load SaProt

Hugging Face model

The following code shows how to load the model with the Hugging Face classes.

from transformers import EsmTokenizer, EsmForMaskedLM

model_path = "/your/path/to/SaProt_650M_AF2"
tokenizer = EsmTokenizer.from_pretrained(model_path)
model = EsmForMaskedLM.from_pretrained(model_path)

#################### Example ####################
device = "cuda"
model.to(device)

seq = "MdEvVpQpLrVyQdYaKv"
tokens = tokenizer.tokenize(seq)
print(tokens)

inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

outputs = model(**inputs)
print(outputs.logits.shape)

"""
['Md', 'Ev', 'Vp', 'Qp', 'Lr', 'Vy', 'Qd', 'Ya', 'Kv']
torch.Size([1, 11, 446])
"""
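
To make the printed shapes easier to read, the sketch below reproduces the token split in plain Python and counts the sequence dimension. It assumes that every token is exactly one residue letter plus one structure letter and that the tokenizer adds two special tokens (`<cls>` at the start and `<eos>` at the end), as ESM tokenizers do. `split_sa_tokens` is a hypothetical helper, not the tokenizer itself, which matches tokens against its vocabulary.

```python
def split_sa_tokens(sa_seq: str) -> list[str]:
    """Split a structure-aware sequence into two-character tokens.

    Hypothetical helper: it only mimics the split shown in the example
    output and does not validate tokens against the model vocabulary.
    """
    if len(sa_seq) % 2 != 0:
        raise ValueError("expected an even number of characters")
    return [sa_seq[i:i + 2] for i in range(0, len(sa_seq), 2)]


tokens = split_sa_tokens("MdEvVpQpLrVyQdYaKv")
num_special_tokens = 2  # assumption: <cls> and <eos> added by the tokenizer

print(tokens)
print(len(tokens) + num_special_tokens)  # sequence dimension of the logits
```

Under these assumptions the 9 structure-aware tokens plus 2 special tokens give the middle dimension of 11 in `torch.Size([1, 11, 446])`; the last dimension is the vocabulary size reported by the model.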

Load SaProt using esm repository

Users can also load SaProt with the esm implementation. The checkpoint is stored in the same Hugging Face folder and is named SaProt_650M_AF2.pt. We provide a function to load the model.

from utils.esm_loader import load_esm_saprot

model_path = "/your/path/to/SaProt_650M_AF2.pt"
model, alphabet = load_esm_saprot(model_path)

Convert protein structure into structure-aware sequence

We provide a function to convert a protein structure into a structure-aware sequence. The function calls the foldseek binary file to encode the structure. You can download the binary file from here and place it in the bin folder. The following code shows how to use it.

from utils.foldseek_util import get_struc_seq
pdb_path = "example/8ac8.cif"

# Extract the "A" chain from the pdb file and encode it into a struc_seq
# pLDDT is used to mask low-confidence regions if "plddt_mask" is True
parsed_seqs = get_struc_seq("bin/foldseek", pdb_path, ["A"])["A"]
seq, foldseek_seq, combined_seq = parsed_seqs

print(f"seq: {seq}")
print(f"foldseek_seq: {foldseek_seq}")
print(f"combined_seq: {combined_seq}")
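
The comment above mentions masking low-confidence regions by pLDDT. A minimal sketch of that idea follows, assuming a per-residue pLDDT list, a threshold of 70, and `#` as the mask character. All three are assumptions made for illustration, as is the function name `mask_low_confidence`; check `utils/foldseek_util.py` for the actual behavior.

```python
from typing import Sequence


def mask_low_confidence(
    struc_tokens: str,
    plddts: Sequence[float],
    threshold: float = 70.0,  # assumed cutoff, not confirmed by the README
    mask_char: str = "#",     # assumed mask symbol, not confirmed by the README
) -> str:
    """Replace structure tokens whose pLDDT is below the threshold."""
    if len(struc_tokens) != len(plddts):
        raise ValueError("need exactly one pLDDT value per structure token")
    return "".join(
        tok if plddt >= threshold else mask_char
        for tok, plddt in zip(struc_tokens, plddts)
    )


# Hypothetical pLDDT values for a four-residue fragment.
print(mask_low_confidence("dvpp", [90.0, 65.0, 70.0, 10.0]))
```

Only the structure tokens are masked in this sketch; the residue letters would stay unchanged when the two strings are combined.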

Predict mutational effect

We provide a function to predict the mutational effect of a protein sequence. The example below shows how to predict the mutational effect at a specific position.

from model.esm.esm_foldseek_mutation_model import EsmFoldseekMutationModel


config = {
    "foldseek_path": None,
    "config_path": "/your/path/to/SaProt_650M_AF2",
    "load_pretrained": True,
}
model = EsmFoldseekMutationModel(**config)
tokenizer = model.tokenizer

device = "cuda"
model.eval()
model.to(device)

seq = "MdEvVpQpLrVyQdYaKv"

# Predict the effect of mutating the 3rd amino acid to A
mut_info = "V3A"
mut_value = model.predict_mut(seq, mut_info)
print(mut_value)

# Predict all effects of mutations at 3rd position
mut_pos = 3
mut_dict = model.predict_pos_mut(seq, mut_pos)
print(mut_dict)

# Predict probabilities of all amino acids at 3rd position
mut_pos = 3
mut_dict = model.predict_pos_prob(seq, mut_pos)
print(mut_dict)

"""
0.7908501625061035

{'V3A': 0.7908501625061035, 'V3C': -0.9117952585220337, 'V3D': 2.7700226306915283, 'V3E': 2.3255627155303955, 'V3F': 0.2094242423772812, 'V3G': 2.699633836746216, 'V3H': 1.240191102027893, 'V3I': 0.10231903940439224, 'V3K': 1.804598093032837,
'V3L': 1.3324960470199585, 'V3M': -0.18938277661800385, 'V3N': 2.8249857425689697, 'V3P': 0.40185314416885376, 'V3Q': 1.8361762762069702, 'V3R': 1.1899691820144653, 'V3S': 2.2159857749938965, 'V3T': 0.8813426494598389, 'V3V': 0.0, 'V3W': 0.5853186249732971, 'V3Y': 0.17449656128883362}

{'A': 0.021275954321026802, 'C': 0.0038764977362006903, 'D': 0.15396881103515625, 'E': 0.0987202599644661, 'F': 0.011895398609340191, 'G': 0.14350374042987823, 'H': 0.03334535285830498, 'I': 0.010687196627259254, 'K': 0.058634623885154724, 'L': 0.03656982257962227, 'M': 0.00798324216157198, 'N': 0.16266827285289764, 'P': 0.014419485814869404, 'Q': 0.06051575019955635, 'R': 0.03171204403042793, 'S': 0.08847439289093018, 'T': 0.023291070014238358, 'V': 0.009647775441408157, 'W': 0.017323188483715057, 'Y': 0.011487090960144997}
"""
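
The printed values are consistent with each mutation score being the log ratio of the mutant probability to the wild-type probability at that position. The sketch below checks this reading against three of the probabilities printed above. `log_odds_scores` is a hypothetical helper written for this check, not the repository's implementation.

```python
import math


def log_odds_scores(probs: dict[str, float], wild_type: str, position: int) -> dict[str, float]:
    """Score each substitution as log(p_mutant / p_wild_type)."""
    wt_prob = probs[wild_type]
    return {
        f"{wild_type}{position}{aa}": math.log(p / wt_prob)
        for aa, p in probs.items()
    }


# Subset of the probabilities printed by predict_pos_prob in the example above.
probs = {
    "A": 0.021275954321026802,
    "D": 0.15396881103515625,
    "V": 0.009647775441408157,
}
scores = log_odds_scores(probs, wild_type="V", position=3)
print(scores)
```

With this reading, a positive score means the model assigns the mutant residue a higher probability than the wild type, and the wild type itself scores 0.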

Prepare dataset

Pre-training dataset

We provide the dataset for pre-training SaProt. The dataset can be downloaded from here.

Downstream tasks

We provide datasets that are used in the paper. Datasets can be downloaded from here.

Once downloaded, the datasets need to be decompressed and placed in the LMDB folder for supervised fine-tuning.

Fine-tune SaProt

We provide a script to fine-tune SaProt on the datasets. The following code shows how to fine-tune SaProt on specific downstream tasks. Before running the code, please make sure that the datasets are placed in the LMDB folder and the Hugging Face version of the SaProt 650M model is placed in the weights/PLMs folder. Note that the default training settings are not the same as in the paper because hardware differs between users. We recommend adjusting the yaml file to your own conditions (e.g. batch_size, devices and accumulate_grad_batches).

# Fine-tune SaProt on the Thermostability task
python scripts/training.py -c config/Thermostability/saprot.yaml

# Fine-tune ESM-2 on the Thermostability task
python scripts/training.py -c config/Thermostability/esm2.yaml
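
When adapting batch_size, devices and accumulate_grad_batches to smaller hardware, one common approach is to keep their product (the effective batch size) fixed. The sketch below assumes batch_size is counted per device, which is how distributed data-parallel training usually works. The target of 64 is a made-up number, not a setting from the paper.

```python
import math


def accumulation_steps(target_effective_batch: int, batch_size: int, devices: int) -> int:
    """Smallest accumulate_grad_batches that reaches the target effective batch size.

    effective batch = batch_size (per device) * devices * accumulate_grad_batches
    """
    samples_per_step = batch_size * devices
    if samples_per_step <= 0:
        raise ValueError("batch_size and devices must be positive")
    return max(1, math.ceil(target_effective_batch / samples_per_step))


# Hypothetical setups: one small GPU versus four larger GPUs.
print(accumulation_steps(target_effective_batch=64, batch_size=2, devices=1))
print(accumulation_steps(target_effective_batch=64, batch_size=4, devices=4))
```

Lowering batch_size while raising accumulate_grad_batches reduces memory use at the cost of more forward and backward passes per optimizer step.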

Record the training process (optional)

If you want to record the training process with wandb, modify the config file to set Trainer.logger = True, then paste your wandb API key into the config key setting.os_environ.WANDB_API_KEY.

Evaluate zero-shot performance

We provide a script to evaluate the zero-shot performance of models (the foldseek binary file must be placed in the bin folder):

# Evaluate the zero-shot performance of SaProt on the ProteinGym benchmark
python scripts/mutation_zeroshot.py -c config/ProteinGym/saprot.yaml

# Evaluate the zero-shot performance of ESM-2 on the ProteinGym benchmark
python scripts/mutation_zeroshot.py -c config/ProteinGym/esm2.yaml

The results will be saved in the output/ProteinGym folder.
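
The results tables above report ProteinGym performance as Spearman's ρ. For readers who want to recompute it from predicted scores and measured fitness values, here is a dependency-free sketch. It is not the repository's evaluation code, and the input lists are made-up numbers.

```python
from math import sqrt


def average_ranks(values: list[float]) -> list[float]:
    """Rank values from 1..n, giving tied values the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with values[order[i]].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def spearman_rho(x: list[float], y: list[float]) -> float:
    """Pearson correlation computed on the ranks of x and y."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two equally long lists with at least two items")
    rx, ry = average_ranks(x), average_ranks(y)
    mean_x, mean_y = sum(rx) / len(rx), sum(ry) / len(ry)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(rx, ry))
    var_x = sum((a - mean_x) ** 2 for a in rx)
    var_y = sum((b - mean_y) ** 2 for b in ry)
    if var_x == 0 or var_y == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / sqrt(var_x * var_y)


predicted = [0.79, -0.91, 2.77, 0.21]  # hypothetical model scores
measured = [0.60, 0.10, 0.90, 0.30]    # hypothetical fitness values
print(spearman_rho(predicted, measured))
```

Because only ranks matter, the metric does not depend on the scale of the model's scores.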

For the ClinVar benchmark, you can use the following script to calculate the AUC metric:

# Evaluate the zero-shot performance of SaProt on the ClinVar benchmark
python scripts/mutation_zeroshot.py -c config/ClinVar/saprot.yaml
python scripts/compute_clinvar_auc.py -c config/ClinVar/saprot.yaml
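
As background on the metric itself, AUC can be computed as the fraction of positive/negative pairs in which the positive example receives the higher score, with ties counting as half. The sketch below implements that definition in plain Python. It is not the code in scripts/compute_clinvar_auc.py, and the choice of which class is labeled 1 and whether scores need to be negated is an assumption left to the user.

```python
def roc_auc(labels: list[int], scores: list[float]) -> float:
    """Pairwise (Mann-Whitney) AUC: P(score of a positive > score of a negative)."""
    positives = [s for label, s in zip(labels, scores) if label == 1]
    negatives = [s for label, s in zip(labels, scores) if label == 0]
    if not positives or not negatives:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(positives) * len(negatives))


# Hypothetical labels and scores.
print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

The quadratic loop keeps the definition visible; for large benchmarks a rank-based implementation from a statistics library is the practical choice.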

Citation

If you find this repository useful, please cite our paper:

@article{su2023saprot,
  title={SaProt: Protein Language Modeling with Structure-aware Vocabulary},
  author={Su, Jin and Han, Chenchen and Zhou, Yuyang and Shan, Junjie and Zhou, Xibin and Yuan, Fajie},
  journal={bioRxiv},
  year={2023},
  publisher={Cold Spring Harbor Laboratory}
}

saprot's Issues

Downstream Annotation task 'EC/GO' finetuning overfitting

I noticed overfitting during fine-tuning on the downstream task when the number of fine-tuning epochs is set to 100.
For example, when fine-tuning the ESM-2 8M model on the GO/BP annotation task, the valid-f1-max keeps rising, but the validation loss shows overfitting.
[image]

Testing the checkpoint with the lowest validation loss (around 50 epochs) gives test-f1-max = 0.2259.
Testing the checkpoint with the highest valid-f1-max (around 100 epochs) gives test-f1-max = 0.2092.

Therefore, I am curious about downstream fine-tuning: during training, do you apply early stopping based on overfitting?

EC GO results

Dear Sir,

Why are the EC and GO results in SaProt so much lower than in the original paper "Enhancing Protein Language Model with Structure-based Encoder and Pre-training"? I wonder if the dataset is different.

Getting protein embeddings

Hello, thank you very much for this wonderful work!

I was just wondering if, given a protein sequence and its corresponding 3Di structure-sequence (obtained via Foldseek), is it possible to extract fixed-length protein embeddings using SaProt? Would there be a sample script for this particular task?

Thank you.

License

Hi,

I'm interested in integrating SaProt into an application and was wondering which license you are releasing the code under.

Thanks beforehand!

Additional input values generated by the tokenizer

Thanks for this impressive solution!

When I run the tokenizer on an input sequence, two additional elements are always added to the tokenizer output. Why is this, and what do the values represent?

For example:
print(len(sequence))  # -> 5
inputs = tokenizer(sequence, return_tensors="pt")
print(inputs['input_ids'].size())  # -> torch.Size([1, 7])

Additionally, I'm trying to generate a fixed-length sequence embedding. I saw you answered how to do this with the ESM model, but is there a way to do so with the Hugging Face model?

Thanks for your help!

The meaning of output size

Thank you for your great work and for sharing the repository!
I am wondering about the meaning of the output size torch.Size([1, 11, 446]) in the README example. I suppose '446' is the size of the vocabulary, but why is it different from the '441' given in your article?
Moreover, do you provide code or scripts to load batches of sequences into your model, as ESM does?
Thank you very much!

Pretraining dataset

Awesome work!

Will the entire 40-million-sequence pretraining dataset be made available?

Setting my configuration for SaProt_650M_AF2

Hi this is a fantastic project.

I want to use your model for a research project I am working on. I am having trouble setting my configuration to where the SaProt_650M_AF2.pt file is stored. How do I get this .pt file from the GitHub?

per residue representations

Dear authors,

Thank you for sharing this great work with us!
I wonder if it's possible to extract per-residue representations, as with ESM-2?

import torch
import esm

# Load ESM-2 model
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

Thank you in advance!

Back Translate 3Di Tokens to PDB Format

Hi there.
I do not have a great understanding of 3D structure formats, so my questions might seem very basic to others:
Can we convert the 3Di tokens produced by foldseek back into actual 3D PDB files?
I want to develop a model that takes a protein sequence and returns its 3Di tokens as its 3D structure. Imagine the model predicted perfectly: how could I use it in practice for 3D structure prediction?
Should I convert the outputs to a specific format? How can I evaluate the model's 3D structure prediction against the true 3D structure?

Many thanks to anyone who can answer my questions.

Some weights of EsmModel were not initialized from the model checkpoint

Hi, thank you very much for sharing this great work.

After loading the pretrained checkpoint, I encounter some warnings:

Some weights of EsmModel were not initialized from the model checkpoint at SaProt_650M_PDB and are newly initialized: ['esm.embeddings.position_embeddings.weight', 'esm.contact_head.regression.bias', 'esm.pooler.dense.bias', 'esm.contact_head.regression.weight', 'esm.pooler.dense.weight']

My code is

from transformers import EsmTokenizer, EsmModel
tokenizer = EsmTokenizer.from_pretrained("SaProt_650M_PDB")
model = EsmModel.from_pretrained("SaProt_650M_PDB").cuda()
model.eval()

I want to use the pre-trained model to obtain representation of a protein directly without finetuning, but I'm not sure if missing these weights affects the quality of the obtained representation, e.g., the absence of position embedding weights.

Fine-tune dataset access

Hello,

I see you have made .mdb dataset files available. How would one go about simply extracting and using the fine-tune data for downstream tasks? I would like to fine-tune my own model so the training script will not work.
Best,
Logan

Clarification Needed on Data Types

Hello, I am currently exploring your project and have two questions regarding the data types:

  1. Raw Data Composition: Does the raw data provided in your repository include 3D coordinates? I checked some of the mdb data and found only amino acid sequences with 3D structure embedding sequences. Can you provide raw data with 3D coordinates?
  2. Data Utilization in Downstream Tasks: In downstream tasks, are you primarily using real data or data generated by AlphaFold? All the data I found has pLDDT values.

Answering these questions would help me a lot. Thanks in advance.

Unable to Open Downstream Task MDB Files in Access After Downloading from GitHub Tutorial

After following the tutorial on GitHub to download the data for downstream tasks and extracting it, we encountered an issue. We are unable to open the MDB files provided for the downstream tasks using Access locally. Upon attempting to open these files, we receive an error message, which I will attach below for reference. We are using Access version 2019. We seek clarification on the possible reasons causing this issue and inquire if alternative versions of the data in different formats can be provided.

Error Message:
[image: screenshot of the error message]

We kindly request your assistance in understanding the root cause of this issue and if possible, providing alternative versions of the data in different formats. This would greatly facilitate our progress in the tasks at hand.

Using ProstT5 or EsmFold to predict foldseek tokens?

1) For some reason, esm_foldseek_model.py is badly formatted.
2) I can't find the config file that corresponds to esm_foldseek_model.py.
3) I can't find the training data for this model on Google Drive.
4) I am really interested in checking the strategies described below.

From your publication:
"In Strategy 1, there is a requirement for simultaneous prediction of both residue and Foldseek tokens. Intuitively, if the accuracy of Foldseek or AF2 structures is insufficient, predicting them simultaneously may result in suboptimal outcomes. On the other hand, Strategy 2 intentionally discloses structure token information during the input phase, allowing the model to primarily concentrate on predicting residue types. Overall, the performance disparity between the two masking strategies is not considerable, with Strategy 2 usually exhibiting some improvements"

I would say the results are almost identical.

I can suggest a Strategy 3 without masking:
input: ESMFold predictions in SA format; output: AF2 predictions in SA format.
As a second choice, ProstT5 predictions could be used.

What needs to be done is to build the training set, which requires some preprocessing of a large amount of data:

Compare the ESM Atlas data (600M proteins) with the AF2 data (200M proteins).
Find proteins with identical sequences and convert them into SA format (ESMFold and AlphaFold2).
My expectation is that the number of such identical sequences will be large enough.

I don't think we need to train from scratch; it can be implemented as a fine-tuning procedure.

Fine-tuned model weights

Hi all, thanks for the open-source release! Are you also planning to release the fine-tuned model weights, namely the Thermostability model?

ClinVar query

Hello, I downloaded the ClinVar .tar.gz file from your directory. I noticed that all of the fitness values are '1.0', whereas the ProteinGym dataset reports various fitness values. Is there a reason you have kept only the entries with a fitness score of '1.0'?

Finetuning GPU memory cost

Hi Sir,

What is the GPU memory cost of fine-tuning the 650M SaProt model? I got an OOM error when running the fine-tuning script with the command from the README. My GPU has 24 GB of memory.

Mismatch between ESM2 pretraining dataset and SaProt pretraining dataset

Hi there, thank you for your impressive work.

I downloaded the pretraining dataset that you published here: https://huggingface.co/datasets/westlake-repl/AF2_UniRef50

However when I load the DB I find that the validation set contains only ~20k uniprots. Your paper says:

"B PRE-TRAINING DATA PROCESSING
We adhere to the procedures outlined in ESM-2 Lin et al. (2022) to generate filtered sequence data, and then we retrieve all AF2 structures via the AlphaFoldDB website https://alphafold.ebi.ac.uk/ based on the UniProt ids of protein sequences, collecting approximately 40 million structures."

And ESM-2 Lin et al. (2022) says:

"A. Materials and Methods
A.1. Data
A.1.1. SEQUENCE DATASET USED TO TRAIN ESM-2
UniRef50, September 2021 version, is used for the training of ESM models. The training dataset was partitioned by randomly selecting 0.5% (≈ 250,000) sequences to form the validation set."

So I am wondering whether there is some data missing (i.e., the remaining ~240k validation uniprots) or if I have done something wrong.

Many thanks in advance.

model for Fine-tune SaProt.

Hello, I'm working at an academy.
I'm trying to run SaProt following the README.

But there is no SaProt model in the weights folder.
I was wondering if you could provide the weights publicly.

Looking forward to hearing from you.

missing pad_sequences

esm_foldseek_dataset.py:

label_ids = pad_sequences(label_ids, -1)

NameError: name 'pad_sequences' is not defined

Please close this issue; I found the function and copied it into my file.

dynamic model/dataset selection is not working

For some reason, in my Python environment the class method does not see the global variable.
Any advice on how to fix this without changing the code would be appreciated.

I have rewritten it using a more straightforward "hack":

import importlib
import os


class ModelInterface:
    @classmethod
    def init_model(cls, model_py_path: str, **kwargs):
        # Turn the path into a dotted module name by replacing the OS separator
        sub_dirs = model_py_path.split(os.sep)
        module_name = '.'.join(sub_dirs[:])
        module = importlib.import_module(module_name)
        # Hack: pick the second name listed in the module instead of a lookup by class name
        objs = dir(module)
        model_cls = getattr(module, objs[1])
        # e.g. "EsmRegressionModel"
        return model_cls(**kwargs)
