
MoleculeSTM: Multi-modal Molecule Structure-text Model for Text-based Editing and Retrieval

Authors: Shengchao Liu, Weili Nie, Chengpeng Wang, Jiarui Lu, Zhuoran Qiao, Ling Liu, Jian Tang*, Chaowei Xiao*, Anima Anandkumar*

* jointly supervised

[Paper] [Project Page] [ArXiv] [Datasets on Hugging Face] [Checkpoints on Hugging Face]

1 Environment

First install conda:

wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh

Then create a virtual environment and install the packages:

conda create -n MoleculeSTM python=3.7
conda activate MoleculeSTM

conda install -y -c rdkit rdkit=2020.09.1.0
conda install -y -c conda-forge -c pytorch pytorch=1.9.1
conda install -y -c pyg -c conda-forge pyg==2.0.3

pip install requests
pip install tqdm
pip install matplotlib
pip install spacy
pip install Levenshtein

# for SciBert
conda install -y boto3
pip install transformers

# for MoleculeNet
pip install ogb==1.2.0

# install pysmilesutils
python -m pip install git+https://github.com/MolecularAI/pysmilesutils.git

pip install deepspeed

# install megatron
# pip install megatron-lm==1.1.5
git clone https://github.com/MolecularAI/MolBART.git --branch megatron-molbart-with-zinc
cd MolBART/megatron_molbart/Megatron-LM-v1.1.5-3D_parallelism
pip install .
cd ../../..

# install apex
# wget https://github.com/NVIDIA/apex/archive/refs/tags/22.03.zip
# unzip 22.03.zip
git clone https://github.com/chao1224/apex.git
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ..

We also provide a Docker setup in the Dockerfile.

2 Datasets and Preprocessing

We provide the raw dataset (after preprocessing) at this Hugging Face link. Alternatively, you can use the following Python script:

from huggingface_hub import snapshot_download

snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="dataset", local_dir='.')

Then move all the downloaded datasets into the ./data folder.
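Alternatively, a small variation on the snippet above (an untested sketch, not an official instruction) downloads the datasets directly into ./data so no manual move is needed:

from huggingface_hub import snapshot_download

# Run from the repository root; the target folder is an assumption.
snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="dataset", local_dir='./data')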

2.1 Pretraining Dataset: PubChemSTM

Useful resources:

  • For molecular structure information (SMILES, 2D molecular graphs, etc.), we can download it from PubChem in SDF format here.
  • For textual data, we may first refer to this PubChem RDF tutorial.
  • The RDF data on the PubChem FTP site is arranged so that you only need to download the type of information you are interested in, avoiding the parts of PubChem data you will not use. For example, if you are only interested in computed chemical properties, you just need the PubChemRDF data in the compound descriptor directory. The link is here.
  • Guidance on using RDF and the REST API can be found here.

As confirmed with the PubChem group, performing research on these data does not violate their license; however, PubChem does not hold the license for the textual data itself, which would necessitate an extensive license evaluation for every structure-text pair in PubChemSTM. This poses a substantial workload and has hindered the release of PubChemSTM. Nevertheless, we have uploaded the structure part of the PubChemSTM data to Hugging Face, and we provide all the details needed to regenerate PubChemSTM as follows (a combined run is sketched after this list):

  1. Go to the preprocessing/PubChemSTM folder.
  2. python step_01_description_extraction.py. This step extracts and merges all the textual descriptions into a single JSON file. We ran this on May 30, 2022; the APIs keep updating, so you may get a slightly different version if you run the script yourself.
  3. bash step_02.sh. This downloads all the SDF files, with SMILES, 2D graphs, and computed molecular properties. This may take hours.
  4. python step_03_filter_out_SDF.py. This filters all the molecules with textual descriptions and saves them in an SDF file. This may take up to 2 hours.
  5. python step_04_merge_SDF.py. This gathers all the molecules into a single SDF file.
  6. python step_05_sample_extraction.py. This generates the CID2SMILES.csv file.
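Putting the steps together, a minimal end-to-end run from the repository root looks like this (assuming each step finishes successfully):

cd preprocessing/PubChemSTM
python step_01_description_extraction.py
bash step_02.sh
python step_03_filter_out_SDF.py
python step_04_merge_SDF.py
python step_05_sample_extraction.py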

2.2 Downstream Datasets

We have included them at the Hugging Face link above. We briefly list the details below (a sketch of the resulting folder layout follows the list):

  • DrugBank_data for zero-shot structure-text retrieval
  • ZINC250K_data for space alignment (step 1 in editing)
  • Editing_data for zero-shot text-guided editing (step 2 in editing)
    • single_multi_property_SMILES.txt for single-objective, multi-objective, binding-affinity-based, and drug relevance editing
    • neighbor2drug for neighborhood searching for patent drug molecules
    • ChEMBL_data for binding editing
  • MoleculeNet_data for molecular property prediction
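Assuming everything is placed under ./data as described at the beginning of this section, the resulting layout looks roughly like this (a sketch inferred from the list above; exact file names may differ):

data/
├── DrugBank_data
├── ZINC250K_data
├── Editing_data
│   ├── single_multi_property_SMILES.txt
│   ├── neighbor2drug
│   └── ChEMBL_data
└── MoleculeNet_data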

3 Checkpoints

3.1 SciBERT

This can be done by simply calling the following for SciBERT:

SciBERT_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder)
SciBERT_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device)
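For reference, a self-contained sketch that loads SciBERT and embeds one description. The cache folder path and device choice below are illustrative assumptions, and the [CLS] pooling shown is only one option (MoleculeSTM's exact pooling may differ):

import torch
from transformers import AutoTokenizer, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_SciBERT_folder = "./data/pretrained_SciBERT"  # assumed cache location

SciBERT_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir=pretrained_SciBERT_folder)
SciBERT_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased", cache_dir=pretrained_SciBERT_folder).to(device)

text = "This molecule is soluble in water."
tokens = SciBERT_tokenizer(text, truncation=True, max_length=512, return_tensors="pt").to(device)
with torch.no_grad():
    text_repr = SciBERT_model(**tokens).last_hidden_state[:, 0, :]  # [CLS] token embedding
print(text_repr.shape)  # torch.Size([1, 768])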

3.2 MegaMolBART

Run download_MegaMolBART.sh (credit to RetMol). The output directory structure looks like this:

├── bart_vocab.txt
└── checkpoints
    ├── iter_0134000
    │   ├── mp_rank_00
    │   │   └── model_optim_rng.pt
    │   ├── mp_rank_00_model_states.pt
    │   ├── zero_pp_rank_0_mp_rank_00optim_states.pt
    │   ├── zero_pp_rank_1_mp_rank_00optim_states.pt
    │   ├── zero_pp_rank_2_mp_rank_00optim_states.pt
    │   ├── zero_pp_rank_3_mp_rank_00optim_states.pt
    │   ├── zero_pp_rank_4_mp_rank_00optim_states.pt
    │   ├── zero_pp_rank_5_mp_rank_00optim_states.pt
    │   ├── zero_pp_rank_6_mp_rank_00optim_states.pt
    │   └── zero_pp_rank_7_mp_rank_00optim_states.pt
    └── latest_checkpointed_iteration.txt

3.3 GNN and GraphMVP

For GraphMVP, check this repo and the checkpoints at this Google Drive link. The expected layout is:

pretrained_GraphMVP/
├── GraphMVP_C
│   └── model.pth
└── GraphMVP_G
    └── model.pth

3.4 Baseline KV-PLM

For KV-PLM, check this repo and the checkpoints at this Google Drive link.

3.5 Checkpoints for MoleculeSTM

We provide two sets of demo checkpoints at this Hugging Face link. Alternatively, you can use the following Python script:

from huggingface_hub import snapshot_download

snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="model", cache_dir='.')

For the optimal results reported in the paper, please use the following script:

from huggingface_hub import snapshot_download

snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="model", local_dir='.', allow_patterns="*MoleculeSTM*")

We further provide the optimal checkpoints for each downstream task under the scripts folder (see its README file).

4 Scripts and Demos

All the running scripts and demos can be found under the scripts folder and demos folder, respectively.

4.1 Pretraining

MoleculeSTM-SMILES

python pretrain.py \
    --verbose --batch_size=8 \
    --molecule_type=SMILES

MoleculeSTM-Graph

python pretrain.py \
    --verbose --batch_size=8 \
    --molecule_type=Graph

4.2 Downstream: Zero-shot Structure-text Retrieval

For DrugBank-Description

MoleculeSTM-SMILES

python downstream_01_retrieval_Description_Pharmacodynamics.py \
    --task=molecule_description_removed_PubChem \
    --molecule_type=SMILES \
    --input_model_dir=../data/demo/demo_checkpoints_SMILES

MoleculeSTM-Graph

python downstream_01_retrieval_Description_Pharmacodynamics.py \
    --task=molecule_description_removed_PubChem \
    --molecule_type=Graph \
    --input_model_dir=../data/demo/demo_checkpoints_Graph

For DrugBank-Pharmacodynamics

MoleculeSTM-SMILES

python downstream_01_retrieval_Description_Pharmacodynamics.py \
    --task=molecule_pharmacodynamics_removed_PubChem \
    --molecule_type=SMILES \
    --input_model_dir=../data/demo/demo_checkpoints_SMILES

MoleculeSTM-Graph

python downstream_01_retrieval_Description_Pharmacodynamics.py \
    --task=molecule_pharmacodynamics_removed_PubChem \
    --molecule_type=Graph \
    --input_model_dir=../data/demo/demo_checkpoints_Graph

For DrugBank-ATC

MoleculeSTM-SMILES

python downstream_01_retrieval_ATC.py \
    --molecule_type=SMILES \
    --input_model_dir=../data/demo/demo_checkpoints_SMILES

MoleculeSTM-Graph

python downstream_01_retrieval_ATC.py \
    --molecule_type=Graph \
    --input_model_dir=../data/demo/demo_checkpoints_Graph

4.3 Downstream: Zero-shot Text-based Molecule Editing

The list of description IDs can be found in MoleculeSTM/downstream_molecule_edit_utils.py.

MoleculeSTM-SMILES

python downstream_02_molecule_edit_step_01_MoleculeSTM_Space_Alignment.py \
    --MoleculeSTM_molecule_type=SMILES \
    --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_SMILES


python downstream_02_molecule_edit_step_02_MoleculeSTM_Latent_Optimization.py \
    --MoleculeSTM_molecule_type=SMILES \
    --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_SMILES \
    --language_edit_model_dir=../data/demo/demo_checkpoints_SMILES \
    --input_description_id=101

MoleculeSTM-Graph

python downstream_02_molecule_edit_step_01_MoleculeSTM_Space_Alignment.py \
    --MoleculeSTM_molecule_type=Graph \
    --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_Graph


python downstream_02_molecule_edit_step_02_MoleculeSTM_Latent_Optimization.py \
    --MoleculeSTM_molecule_type=Graph \
    --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_Graph \
    --language_edit_model_dir=../data/demo/demo_checkpoints_Graph \
    --input_description_id=101

4.4 Downstream: Molecular Property Prediction

MoleculeSTM-SMILES

python downstream_03_property_prediction.py \
    --dataset=bace --molecule_type=SMILES

MoleculeSTM-Graph

python downstream_03_property_prediction.py \
    --dataset=bace --molecule_type=Graph

4.5 Demo

Please check the demos folder. This may require you to download the datasets and checkpoints first:
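A minimal sketch, reusing the Hugging Face download snippets from Sections 2 and 3.5 (downloading the datasets directly into ./data is an assumption; the original snippet downloads into the current directory):

from huggingface_hub import snapshot_download

# Datasets (Section 2)
snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="dataset", local_dir='./data')
# Demo checkpoints (Section 3.5)
snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="model", cache_dir='.')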

Cite Us

Feel free to cite this work if you find it useful!

@article{liu2023moleculestm,
    title={Multi-modal molecule structure--text model for text-based retrieval and editing},
    author={Liu, Shengchao and Nie, Weili and Wang, Chengpeng and Lu, Jiarui and Qiao, Zhuoran and Liu, Ling and Tang, Jian and Xiao, Chaowei and Anandkumar, Anima},
    journal={Nature Machine Intelligence},
    year={2023},
    month={Dec},
    day={01},
    volume={5},
    number={12},
    pages={1447--1457},
    issn={2522-5839},
    doi={10.1038/s42256-023-00759-6},
    url={https://doi.org/10.1038/s42256-023-00759-6}
}


moleculestm's Issues

4.3 Downstream: Zero-shot Text-based Molecule Editing

When running the MoleculeSTM-Graph editing task with the command

python downstream_02_molecule_edit_step_01_MoleculeSTM_Space_Alignment.py \
    --MoleculeSTM_molecule_type=Graph \
    --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_Graph

the following error occurred:

RuntimeError: Error(s) in loading state_dict for GNN_graphpred:
size mismatch for molecule_node_model.atom_encoder.atom_embedding_list.1.weight: copying a param with shape torch.Size([4, 300]) from checkpoint, the shape in current model is torch.Size([5, 300]).

How can I solve this?

[Understanding] For a beginner or amateur

Curious to Join The Club

Hi, I am new to this Chemistry + LLM line of work, and I find it a little difficult to understand how everything fits together. Could anyone please help me understand and learn more so I can play with this genre?

Cheers!

Docker ERROR: failed to receive status: rpc error

Hi, thank you for your work.

I aim to use MegaMolBART for encoding and decoding in my work. Therefore, I tried to build the required environment using the provided Dockerfile but am getting this error:

 => [22/22] RUN cd /tmp/apex/ && pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 
 => => #     g++ -pthread -shared -B /opt/conda/compiler_compat -L/opt/conda/lib -Wl,-rpath=/opt/conda/lib -Wl,--no-as-needed -Wl,--sysroot=/ /tmp/pip-req-build-9per1b4n/build/temp.linux-x86_64-3.7/csrc/flatten_unflatten.o -L/opt/c
 => => # onda/lib/python3.7/site-packages/torch/lib -lc10 -ltorch -ltorch_cpu -ltorch_python -o build/lib.linux-x86_64-3.7/apex_C.cpython-37m-x86_64-linux-gnu.so                                                                       
 => => #     building 'amp_C' extension                                                                                                                                                                                                 
 => => #     Emitting ninja build file /tmp/pip-req-build-9per1b4n/build/temp.linux-x86_64-3.7/build.ninja...                                                                                                                           
 => => #     Compiling objects...                                                                                                                                                                                                       
 => => #     Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)                                                                                                          
ERROR: failed to receive status: rpc error: code = Unavailable desc = error reading from server: EOF
PS D:\pythonProject\ConGen> sudo systemctl status docker   # For systems using systemd

Do you have any idea why this occurs? If not, given that I specifically need to use MegaMolBART, do you have any recommendations on how I can achieve that?

wget returns empty zip file for MegaMolBART checkpoints

Hi,
I am trying to get the pretrained MegaMolBART files as described here.

However, something is not working after the wget command:

wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1O5u8b_n93HOrsjN1aezq6NhZojh-6dEe' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1O5u8b_n93HOrsjN1aezq6NhZojh-6dEe" -O models.zip && rm -rf /tmp/cookies.txt

It returns an empty zip file. Here is what I get after the wget command:

--2024-01-31 08:35:52-- https://docs.google.com/uc?export=download&confirm=&id=1O5u8b_n93HOrsjN1aezq6NhZojh-6dEe
Resolving docs.google.com (docs.google.com)... 74.125.132.139, 74.125.132.100, 74.125.132.138, ...
Connecting to docs.google.com (docs.google.com)|74.125.132.139|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://drive.usercontent.google.com/download?id=1O5u8b_n93HOrsjN1aezq6NhZojh-6dEe&export=download [following]
--2024-01-31 08:35:52-- https://drive.usercontent.google.com/download?id=1O5u8b_n93HOrsjN1aezq6NhZojh-6dEe&export=download
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 64.233.182.132, 2607:f8b0:4001:c0a::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|64.233.182.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2422 (2.4K) [text/html]
Saving to: ‘models.zip’

models.zip 100%[===================>] 2.37K --.-KB/s in 0s

2024-01-31 08:35:52 (48.1 MB/s) - ‘models.zip’ saved [2422/2422]

Can you help me with this please?

Current conda environment setup instructions do not work

The current instructions for creating a conda environment do not work. In particular, issues arise with pytorch and pytorch-geometric. I have attempted installing with pip instead, using:

pip3 install torch==1.9.1 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu111 --no-cache-dir

However, getting the version of pytorch-geometric to match has proven difficult. Using conda as in the install instructions always disrupts the installed pytorch version and causes downstream incompatibilities. I've tried pip, but neither of the following commands works:

pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-1.9.1+cu111.html
pip3 install torch_geometric torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-1.9.1+cu111.html

OSError: _ZN5torch3jit17parseSchemaOrNameERKSs

Hello, I have installed all the required packages following the README.
When I run any script in Section 4, I get this error, which seems to indicate a compatibility issue between the installed versions of PyTorch and torch-sparse.

Run:

python pretrain.py \
    --verbose --batch_size=8 \
    --molecule_type=Graph

Error:

Traceback (most recent call last):
  File "pretrain.py", line 13, in <module>
    from torch_geometric.loader import DataLoader as pyg_DataLoader
  File "/home/username/miniconda3/envs/MoleculeSTM/lib/python3.7/site-packages/torch_geometric/__init__.py", line 4, in <module>
    import torch_geometric.data
  File "/home/username/miniconda3/envs/MoleculeSTM/lib/python3.7/site-packages/torch_geometric/data/__init__.py", line 1, in <module>
    from .data import Data
  File "/home/username/miniconda3/envs/MoleculeSTM/lib/python3.7/site-packages/torch_geometric/data/data.py", line 3, in <module>
    from torch_geometric.typing import OptTensor, NodeType, EdgeType
  File "/home/username/miniconda3/envs/MoleculeSTM/lib/python3.7/site-packages/torch_geometric/typing.py", line 4, in <module>
    from torch_sparse import SparseTensor
  File "/home/username/miniconda3/envs/MoleculeSTM/lib/python3.7/site-packages/torch_sparse/__init__.py", line 16, in <module>
    f'{library}_{suffix}', [osp.dirname(__file__)]).origin)
  File "/home/username/miniconda3/envs/MoleculeSTM/lib/python3.7/site-packages/torch/_ops.py", line 104, in load_library
    ctypes.CDLL(path)
  File "/home/username/miniconda3/envs/MoleculeSTM/lib/python3.7/ctypes/__init__.py", line 364, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /home/username/miniconda3/envs/MoleculeSTM/lib/python3.7/site-packages/torch_sparse/_version_cpu.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSs

I also found some useful info for this issue:
pyg-team/pytorch_geometric#999
pyg-team/pytorch_geometric#3836

pytorch=1.9.1, pyg==2.0.3, PackagesNotFoundError

Hi, this is a very cool work, thanks for sharing this repository.
When I followed the commands below to install the pytorch and pyg packages, I encountered the following problem.

command:
conda install -y -c conda-forge -c pytorch pytorch=1.9.1
conda install -y -c pyg -c conda-forge pyg==2.0.3

problem:
PackagesNotFoundError: The following packages are not available from current channels:

Then I used the search bar at the top of the web page (https://anaconda.org), but the required package version was not found.

I would like to know if you have encountered the same problem and how you solved it.
Thank you in advance :)

Processing ".cif" file.

Hi, great work! I wonder if it is possible to input a .cif file containing the molecular structure and extract its latent features.

Toy Checkpoints for MoleculeSTM

Hi, this is a very cool work, thanks for sharing this repository.
We found that the performance of the toy checkpoints provided on Hugging Face is much lower than the results in the paper; we wonder whether this is because they have not been trained for enough epochs. Could you please provide the hyperparameters of the toy checkpoints?
Thanks so much!

RuntimeError: RNG state is wrong size

Hello!
Thank you for your excellent work! I wanted to try the scripts you provided and downloaded the relevant checkpoints following your tutorial. But when I used python pretrain.py --verbose --batch_size=32 --molecule_type=SMILES --epochs=2 to run the pretraining script, the following error occurred:

arguments        Namespace(seed=42, device=0, dataspace_path='../data', dataset='PubChemSTM', text_type='SciBERT', molecule_type='SMILES', representation_frozen=False, batch_size=32, text_lr=0.0001, mol_lr=1e-05, text_lr_scale=1, mol_lr_scale=1, num_workers=8, epochs=2, decay=0, verbose=True, output_model_dir=None, max_seq_len=512, megamolbart_input_dir='../data/pretrained_MegaMolBART/checkpoints', vocab_path='../MoleculeSTM/bart_vocab.txt', pretrain_gnn_mode='GraphMVP_G', gnn_emb_dim=300, num_layer=5, JK='last', dropout_ratio=0.5, gnn_type='gin', graph_pooling='mean', SSL_loss='EBM_NCE', SSL_emb_dim=256, CL_neg_samples=1, T=0.1, normalize=True)
len of CID2text: 324292                                                                              
len of CID2SMILES: 324270                                                                            
CID 28145 missing
CID 24606 missing
CID 24594 missing
CID 61654 missing
CID 24637 missing
CID 28117 missing
CID 21863527 missing
CID 61861 missing
CID 24258 missing
CID 61851 missing
CID 28127 missing
CID 28116 missing
CID 6857667 missing
CID 5460533 missing
CID 5460520 missing
CID 139646 missing
CID 5460519 missing
CID 11966241 missing
CID 24906310 missing
CID 13847619 missing
CID 60211070 missing
CID 28299 missing
missing 22
len of text_list: 361885
using world size: 1 and model-parallel size: 1 
using torch.float32 for parameters ...
-------------------- arguments --------------------
  adam_beta1 ...................... 0.9
  adam_beta2 ...................... 0.999
  adam_eps ........................ 1e-08
  adlr_autoresume ................. False
  adlr_autoresume_interval ........ 1000
  apply_query_key_layer_scaling ... False
  apply_residual_connection_post_layernorm  False
  attention_dropout ............... 0.1
  attention_softmax_in_fp32 ....... False
  batch_size ...................... None
  bert_load ....................... None
  bias_dropout_fusion ............. False
  bias_gelu_fusion ................ False
  block_data_path ................. None
  checkpoint_activations .......... False
  checkpoint_in_cpu ............... False
  checkpoint_num_layers ........... 1
  clip_grad ....................... 1.0
  contigious_checkpointing ........ False
  cpu_optimizer ................... False
  cpu_torch_adam .................. False
  data_impl ....................... infer
  data_path ....................... None
  dataset_path .................... None
  DDP_impl ........................ local
  deepscale ....................... False
  deepscale_config ................ None
  deepspeed ....................... False
  deepspeed_activation_checkpointing  False
  deepspeed_config ................ None
  deepspeed_mpi ................... False
  distribute_checkpointed_activations  False                                                         
  distributed_backend ............. nccl                                                             
  dynamic_loss_scale .............. True                                                             
  eod_mask_loss ................... False
  eval_interval ................... 1000
  eval_iters ...................... 100
  exit_interval ................... None
  faiss_use_gpu ................... False
  finetune ........................ False
  fp16 ............................ False
  fp16_lm_cross_entropy ........... False
  fp32_allreduce .................. False
  gas ............................. 1
  hidden_dropout .................. 0.1
  hidden_size ..................... 256
  hysteresis ...................... 2
  ict_head_size ................... None
  ict_load ........................ None
  indexer_batch_size .............. 128
  indexer_log_interval ............ 1000
  init_method_std ................. 0.02
  load ............................ ../data/pretrained_MegaMolBART/checkpoints                       
  local_rank ...................... None
  log_interval .................... 100
  loss_scale ...................... None
  loss_scale_window ............... 1000
  lr .............................. None
  lr_decay_iters .................. None
  lr_decay_style .................. linear
  make_vocab_size_divisible_by .... 128
  mask_prob ....................... 0.15
  max_position_embeddings ......... 512
  merge_file ...................... None
  min_lr .......................... 0.0
  min_scale ....................... 1
  mmap_warmup ..................... False
  model_parallel_size ............. 1
  no_load_optim ................... False
  no_load_rng ..................... False
  no_save_optim ................... False
  no_save_rng ..................... False
  num_attention_heads ............. 8
  num_layers ...................... 4
  num_unique_layers ............... None
  num_workers ..................... 2
  onnx_safe ....................... None
  openai_gelu ..................... False
  override_lr_scheduler ........... False
  param_sharing_style ............. grouped
  params_dtype .................... torch.float32
  partition_activations ........... False
  pipe_parallel_size .............. 0
  profile_backward ................ False
  query_in_block_prob ............. 0.1
  rank ............................ 0
  report_topk_accuracies .......... []
  reset_attention_mask ............ False
  reset_position_ids .............. False
  save ............................ None
  save_interval ................... None
  scaled_masked_softmax_fusion .... False
  scaled_upper_triang_masked_softmax_fusion  False 
  seed ............................ 1234
  seq_length ...................... None
  short_seq_prob .................. 0.1
  split ........................... 969, 30, 1
  synchronize_each_layer .......... False
  tensorboard_dir ................. None
  titles_data_path ................ None
  tokenizer_type .................. GPT2BPETokenizer
  train_iters ..................... None
  use_checkpoint_lr_scheduler ..... False
  use_cpu_initialization .......... False
  use_one_sent_docs ............... False
  vocab_file ...................... ../MoleculeSTM/bart_vocab.txt
  warmup .......................... 0.01
  weight_decay .................... 0.01
  world_size ...................... 1
  zero_allgather_bucket_size ...... 0.0
  zero_contigious_gradients ....... False
  zero_reduce_bucket_size ......... 0.0
  zero_reduce_scatter ............. False
  zero_stage ...................... 1.0
---------------- end of arguments ---------------- 
> initializing torch distributed ...
> initializing model parallel with size 1                                                            
> setting random seeds to 1234 ...                                                                   
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234
Loading vocab from ../MoleculeSTM/bart_vocab.txt.
Loading from ../data/pretrained_MegaMolBART/checkpoints
global rank 0 is loading checkpoint ../data/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt
could not find arguments in the checkpoint ...
Traceback (most recent call last):
  File "/data02/luozc/chemgpt/MoleculeSTM/scripts/pretrain.py", line 270, in <module>
    MegaMolBART_wrapper = MegaMolBART(
  File "/data02/luozc/chemgpt/MoleculeSTM/scripts/../MoleculeSTM/models/mega_molbart/mega_mol_bart.py", line 98, in __init__
    self.model = self.load_model(args, self.tokenizer, decoder_max_seq_len)
  File "/data02/luozc/chemgpt/MoleculeSTM/scripts/../MoleculeSTM/models/mega_molbart/mega_mol_bart.py", line 157, in load_model
    self.iteration = load_checkpoint(model, None, None)
  File "/home/oem/anaconda3/envs/chemgpt/lib/python3.10/site-packages/megatron/checkpointing.py", line 287, in load_checkpoint
    torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
  File "/home/oem/anaconda3/envs/chemgpt/lib/python3.10/site-packages/torch/cuda/random.py", line 75, in set_rng_state
    _lazy_call(cb)
  File "/home/oem/anaconda3/envs/chemgpt/lib/python3.10/site-packages/torch/cuda/__init__.py", line 229, in _lazy_call
    callable()
  File "/home/oem/anaconda3/envs/chemgpt/lib/python3.10/site-packages/torch/cuda/random.py", line 73, in cb
    default_generator.set_state(new_state_copy)
RuntimeError: RNG state is wrong size

How should I fix this? I'm using a server with an NVIDIA H800, with cuda==12.1 and pytorch==2.1.2.
Thanks again 🙏

Hyperparameters for pre-training

Hi,
First of all, thank you for your work! I noticed that the default hyperparameters in the pre-training scripts seem inconsistent with the paper, and some are not mentioned in the paper at all.
Can you confirm that these are the hyperparameters that were ultimately chosen? If not, could you provide the hyperparameters used for pre-training?
Thank you.

Code available?

Hi, thank you for opening this repository.
Your work is quite interesting!
I would like to run your code for pretraining and the downstream tasks. Is there any estimate of when your dataset and code will be publicly available?

Thank you in advance:)

How do I get the latest "CID2SMILES.csv" data file

Hi, this is a very cool work, thanks for sharing this repository.

The code on GitHub does not generate the latest CID2SMILES.csv file, and only an earlier version is uploaded on Hugging Face. How do I get the latest version of this file?

Thank you in advance :)

Missing download_MegaMolBART.sh

Hello, thank you for sharing this fantastic work.

I noticed that during the pretraining stage, as outlined in the paper, the SMILES encoder uses MegaMolBART's encoder. However, I could not find the download_MegaMolBART.sh file in this repository. Additionally, I checked the paper you cited in MoleculeSTM, "Chemformer: a pre-trained transformer for computational chemistry," but it seems that they did not provide the model checkpoint.

Could you please advise me on how I can obtain the pretrained checkpoint for the SMILES encoder that uses MegaMolBART's encoder?

Unexpected key(s) in state_dict: "embeddings.position_ids"

Hello, I have downloaded the toy checkpoint that you provided.
However, when I try to run the script downstream_01_retrieval_ATC_Retrieval.py, the text model state dict does not match scibert_scivocab_uncased. Is the text model slightly different from SciBERT?


Where is CID2text.json file?

Thanks for sharing the great work. I was trying to run demo_pretrain_Graph.ipynb and cannot run the last step.

I found that in MoleculeSTM/datasets/PubChemSTM.py, class PubChemSTM_Datasets_Graph(InMemoryDataset) uses the dataset paths

self.SDF_file_path = os.path.join(self.root, "raw/molecules.sdf")
self.CID2text_file = os.path.join(self.root, "raw/CID2text.json")

but I cannot find raw/CID2text.json under the data/PubChemSTM_data folder. I was wondering if the error above is related to this missing file. Please let me know if you have any suggestions. Thanks!

file missing

Hi,
I'm very interested in your work, but while reproducing it I cannot find this file: download_MegaMolBART.sh

Docker Failed (RDKit)

Issue: The Docker build fails when installing rdkit.
Attempted workarounds: tried removing the rdkit version pin, and also tried pip install rdkit.

docker build -t molecule_stm .
[+] Building 40.7s (13/27)                                 docker:default
 => [internal] load build definition from Dockerfile                 0.0s
 => => transferring dockerfile: 1.39kB                               0.0s 
 => [internal] load metadata for nvcr.io/nvidia/pytorch:22.01-py3    1.6s 
 => [internal] load .dockerignore                                    0.0s
 => => transferring context: 2B                                      0.0s 
 => [ 1/24] FROM nvcr.io/nvidia/pytorch:22.01-py3@sha256:06f27ba669  0.0s 
 => CACHED [ 2/24] RUN useradd -ms /bin/bash shengchaol              0.0s 
 => CACHED [ 3/24] WORKDIR /home/shengchaol                          0.0s 
 => CACHED [ 4/24] RUN chmod -R 777 /home/shengchaol                 0.0s 
 => CACHED [ 5/24] RUN chmod -R 777 /usr/bin                         0.0s 
 => CACHED [ 6/24] RUN chmod -R 777 /bin                             0.0s 
 => CACHED [ 7/24] RUN chmod -R 777 /usr/local                       0.0s 
 => CACHED [ 8/24] RUN chmod -R 777 /opt/conda                       0.0s 
 => CACHED [ 9/24] RUN conda install -y python=3.7                   0.0s 
 => ERROR [10/24] RUN conda install -y -c rdkit rdkit=2020.09.1.0   39.1s 
------
 > [10/24] RUN conda install -y -c rdkit rdkit=2020.09.1.0:
1.165 Collecting package metadata (current_repodata.json): ...working... done
14.62 Solving environment: ...working... failed with initial frozen solve. Retrying with flexible solve.
30.82 Solving environment: ...working... failed with repodata from current_repodata.json, will retry with next repodata source.
38.38
38.38 ResolvePackageNotFound:
38.38   - conda==4.11.0
38.38
------
Dockerfile:20
--------------------
  18 |     RUN conda install -y python=3.7
  19 |
  20 | >>> RUN conda install -y -c rdkit rdkit=2020.09.1.0
  21 |     # RUN pip install rdkit
  22 |     RUN conda install -y -c conda-forge -c pytorch pytorch=1.9.1   
--------------------
ERROR: failed to solve: process "/bin/sh -c conda install -y -c rdkit rdkit=2020.09.1.0" did not complete successfully: exit code: 1

AssertionError in step 1 in preprocessing

Hi, thank you very much for sharing this great work.

When I run python step_01_description_extraction.py in step 1 of preprocessing the PubChemSTM dataset, an assertion error occurs:

Traceback (most recent call last):                                                                                                                                                                                                              
  File "step_01_description_extraction.py", line 162, in <module>                                                                                                                                                                               
    assert description_data["TotalPages"] == total_page_num                                                                                                                                                                                     
AssertionError

I opened https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/heading/json?heading_type=Compound&heading=Record+Description&page=0 and found that "TotalPages" is 386. I wonder if this affects other parts of the data preprocessing code, and whether there is any plan to provide up-to-date code for processing the data?

demos

Following up on my previous post, I tried to run all the demos; however, each of them failed at its last step with pretty much the same error:

TypeError                                 Traceback (most recent call last)
<ipython-input> in <cell line: 9>()
      8
      9 for epoch in range(1, args.epochs + 1):
---> 10     loss_acc = train_func(model, device, train_loader, optimizer)
     11     print("Epoch: {}\nLoss: {}".format(epoch, loss_acc))
     12

5 frames
/usr/local/lib/python3.10/dist-packages/torch/_utils.py in reraise(self)
    720     # instantiate since we don't know how to
    721     raise RuntimeError(msg) from None
--> 722     raise exception
    723
    724

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.10/dist-packages/torch_geometric/data/dataset.py", line 289, in __getitem__
    data = self.get(self.indices()[idx])
  File "/content/drive/MyDrive/MoleculeSTM/MoleculeSTM/datasets/MoleculeNet_Graph.py", line 184, in get
    for key in self.data.keys:
TypeError: 'method' object is not iterable
Could you provide some insights on how to address this issue? Thank you!

About mol_to_graph_data_obj_simple functions

Thank you for an interesting repo.

I went through the code and noticed that you used two different mol_to_graph_data_obj_simple functions for contrastive pre-training and property-prediction fine-tuning.
pre-training: https://github.com/chao1224/MoleculeSTM/blob/main/MoleculeSTM/datasets/utils.py#L44
fine-tuning: https://github.com/chao1224/MoleculeSTM/blob/main/MoleculeSTM/datasets/MoleculeNet_Graph.py#L17

Could you explain why? Since you use the same GNN architecture for pre-training and fine-tuning, does using different mol_to_graph_data_obj_simple functions affect the GNN's behavior?

Looking forward to hearing from you soon.

Thanks.

How can I reproduce the results from the article?

Hello! I am using the code and checkpoints you provided for molecule editing, but with the default parameters I cannot seem to reproduce the results reported in the paper. Could you tell me which parameter settings I might have gotten wrong? :)

Not getting expected results for zero-shot molecule editing

For zero-shot molecule editing, I am not getting the expected results shown in the repository's notebook.
Do you have any suggestions to improve this result?

This is what I get.

===== for text prompt: This molecule is soluble in water. =====
===== for SMILES OC1C2C1CC2 =====
WARNING: MOLECULE VALIDATION AND SANITIZATION CURRENTLY DISABLED
Use random noise for init
l2 lambda: 1.0
Use random noise for init
100%|██████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.92it/s]
clip loss: -0.45563	L2 loss: 0.23571
WARNING: MOLECULE VALIDATION AND SANITIZATION CURRENTLY DISABLED
SMILES_list: ['OC1C2C1CC2', '[73Se][O+][O+][O+][O+][O+][O+][O+][73Se][73Se][73Se][73Se][73Se][F-][32P][73Se][F-][32P][73Se][O+][O+][73Se][F-][73Se][F-][73Se][73Se][73Se][73Se][73Se][73Se][F-][32P][73Se][F-]1[SiH3-][73Se]', '[73Se]S[3H][Rb][85SrH2][SiH3-]LogD_change_(0.9, 1.1][SiH3-][C@H][3H][Rb][85SrH2][SiH3-][I-][123I][73Se][123I][73Se][123I][73Se]<UNUSED_166>[F-]<UNUSED_81>Clint_low->high[F-]<UNUSED_81>[C@H][3H][Rb][85SrH2][SiH3-][85SrH2][SiH3-][85SrH2][SiH3-][85SrH2][SiH3-][I-][123I][73Se]S[3H][Rb][85SrH2][SiH3-][SiH][SiH2]<UNUSED_0>Clint_low->high[F-]<UNUSED_81>Clint_low->high[KH][11CH]LogD_change_(0.9, 1.1]Br<UNUSED_0>Clint_low->highLogD_change_(0.9, 1.1]Br[Se][123I][85SrH2][C@H][3H][Rb][85SrH2][C@H][3H][Rb][85SrH2][C@H][3H][Rb][85SrH2][C@H][C@H][C@H][C@H][3H][Rb][85SrH2][O+][O+][O+][O+][O+][O+][O+][O+][O+][O+][O+][O+][O+][123I][73Se]S[3H]Clint_low->high<UNUSED_64><UNUSED_81>[C@H]<UNUSED_61><UNUSED_81>Clint_low->high[KH]<UNUSED_46><UNUSED_61>[C@H][3H][123I][O+][123I][O+][123I][P@@+]LogD_change_(1.1, 1.3][PH][C@H][3H][Rb][85SrH2][C@H][3H][Rb][C@H][3H][Rb][C@H][3H][Rb][C@H][3H][123I][73Se]S[3H]Clint_low->high[KH]Clint_low->high[KH][123I]/[Al]<UNUSED_182>[18OH][3H][123I]S[3H][123I]S[3H][123I][P@@+][KH]LogD_change_(2.1, 2.3]LogD_change_(2.1, 2.3]LogD_change_(2.1, 2.3]LogD_change_(2.1, 2.3]<UNUSED_161>[123I][73Se][123I][73Se][123I][73Se]S[3H][32P][PH][C@H][3H]Clint_low->high[KH][S@@+]S[3H]<UNUSED_0>S[3H]<UNUSED_0>S[3H]<UNUSED_0>S[3H]Clint_low->high[KH][S@@+]S[3H]<UNUSED_0>S[3H][Rb][C@H][3H][Rb][C@H][3H][Rb][C@H][3H][Rb][C@H][3H][Rb][85SrH2][SiH3-]S[3H][SiH][PH][3H][32P][PH][3H][Rb][C@H][3H][32P][PH][3H][32P][PH][3H][32P][PH][S@@][32P][PH][3H]<UNUSED_0>LogD_change_(2.1, 2.3]<UNUSED_81>[C@H][3H][Rb][C@H][3H]Clint_low->high<UNUSED_64>[C@H][3H][Rb][C@H][3H][Rb][85SrH2][C@H][3H][Rb][85SrH2][C@H][3H][Rb][C@H][3H][Rb][C@H][3H][Rb][C@H][3H][Rb][C@H][3H][Rb][C@H][SiH][123I][P@@+]LogD_change_(0.3, 0.5][PH]S[3H][Rb][85SrH2][O+][O+][O+][123I][85SrH2][PH][3H][Rb][85SrH2][123I][85SrH2][123I][85SrH2][PH][3H][Rb][C@H][3H][Rb][85SrH2][PH][OH+][NH4+]<UNUSED_193>[PH]<UNUSED_193>[OH+][NH4+]<UNUSED_193>LogD_change_(0.9, 1.1][P@@+][As-][PH][3H][Rb][C@H][3H][Rb][C@H][3H]Clint_low->high[PH][3H][Rb][C@H][3H][Rb][C@H][3H][Rb][C@H][3H][Rb][85SrH2][PH][3H][Rb][C@H][3H][Rb][C@H][3H][Rb][C@H][3H][Rb][C@H][PH][AsH3][C@H][3H][Rb][C@H][3H][Rb][C@H][3H][Rb][C@H][SiH][PH][AsH3][C@H][3H][Rb][85SrH2]<UNUSED_64>[PH][85SrH2][PH][3H][Rb][85SrH2][PH][C@H][3H][Rb][C@H][3H][Rb][85SrH2][O+]LogD_change_(2.1, 2.3][F-]<UNUSED_81>[C@H][SiH][123I][As-][PH][C@H][SiH][PH][AsH3][C@H][SiH][PH][P@@+][As-][PH][C@H][3H][Rb][85SrH2][PH][3H][Rb][85SrH2][PH][3H][Rb][C@H][SiH][123I]<UNUSED_193><UNUSED_68><UNUSED_0>[As-][PH][S@@]<UNUSED_46>[PH]<UNUSED_193>LogD_change_(0.9, 1.1][TeH2][3H][NH4+]<UNUSED_193>[PH][3H][Rb][85SrH2][O+][123I][As-][PH][3H][Rb][85SrH2][PH][P@@+][As-][PH][P@@+][P@@+][P@@+][As-][PH][P@@+][As-][PH][3H][TeH2][3H][Rb][85SrH2][PH][3H][Rb][85SrH2][PH][3H][Rb][85SrH2][PH][P@@+][As-][PH][85SrH2][SiH3-][85SrH2][PH]<UNUSED_0>[TeH2][3H]<UNUSED_0>[TeH2][3H]<UNUSED_0>[TeH2][3H][Rb][85SrH2][PH]LogD_change_(3.5, 3.7][PH][85SrH2][PH][85SrH2][PH][85SrH2][PH][C@H][PH][C@H][PH][P@@+]LogD_change_(0.3, 0.5][PH][P@@+]LogD_change_(0.3, 0.5][PH][P@@+][PH][P@@+][PH][85SrH2][PH][85SrH2][PH]']
