
targetdiff's People

Contributors

amorehead, guanjq

targetdiff's Issues

binding affinity embedding dimension issue

I moved scripts/property_prediction/inference.py into the targetdiff directory.
I ran "python inference.py --ckpt_path pretrained_models/egnn_pdbbind_v2016.pt --protein_path examples/1h36_A_rec_1h36_r88_lig_tt_docked_0_pocket10.pdb --ligand_path examples/1h36_A_rec_1h36_r88_lig_tt_docked_0.sdf"

and encountered:
RuntimeError: Error(s) in loading state_dict for PropPredNet:
size mismatch for ligand_atom_emb.weight: copying a param with shape torch.Size([256, 30]) from checkpoint, the shape in current model is torch.Size([256, 31]).

The embedding dimension of the checkpoint differs from that of the current model.

in case of my own binding complex

I used an AI ligand-protein docking tool, so I now have a ligand SDF file with docked coordinates.

I want to extract the pocket, but the method described on GitHub is confusing in this case.

How can I extract the pocket when I have a protein PDB file and a ligand SDF file?
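One common convention (and what the "pocket10" naming in this repo suggests) is to keep every residue that has at least one atom within 10 Å of any ligand atom. Below is a minimal plain-Python sketch that hand-parses a single V2000 SDF molecule and PDB ATOM records; for real data a library such as Biopython or RDKit is more robust.

```python
import math

def ligand_coords_from_sdf(sdf_text):
    """Read atom coordinates from the first molecule of a V2000 SDF."""
    lines = sdf_text.splitlines()
    n_atoms = int(lines[3][:3])  # counts line: first 3-char field = atom count
    coords = []
    for line in lines[4:4 + n_atoms]:
        x, y, z = float(line[0:10]), float(line[10:20]), float(line[20:30])
        coords.append((x, y, z))
    return coords

def extract_pocket(pdb_text, lig_coords, cutoff=10.0):
    """Keep every residue with at least one atom within `cutoff` of the ligand."""
    atoms = [l for l in pdb_text.splitlines() if l.startswith(("ATOM", "HETATM"))]
    # group atom records by (chain ID, residue sequence number)
    residues = {}
    for l in atoms:
        residues.setdefault((l[21], l[22:26]), []).append(l)
    kept = []
    for res_lines in residues.values():
        for l in res_lines:
            x, y, z = float(l[30:38]), float(l[38:46]), float(l[46:54])
            if any(math.dist((x, y, z), c) <= cutoff for c in lig_coords):
                kept.extend(res_lines)  # one close atom keeps the whole residue
                break
    return "\n".join(kept)
```

Writing the kept lines back out (plus an END record) gives a pocket PDB you can feed to the sampling scripts.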

Finally, your advice on the environment was very helpful.
I am sharing the YAML content:

name: targetdiff
channels:
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_kmp_llvm
  - absl-py=1.4.0=pyhd8ed1ab_0
  - aiohttp=3.8.4=py38h01eb140_1
  - aiosignal=1.3.1=pyhd8ed1ab_0
  - async-timeout=4.0.2=pyhd8ed1ab_0
  - attrs=23.1.0=pyh71513ae_1
  - blas=1.0=mkl
  - blinker=1.6.2=pyhd8ed1ab_0
  - boost=1.74.0=py38h2b96118_5
  - boost-cpp=1.74.0=h75c5d50_8
  - brotli-python=1.0.9=py38hfa26641_9
  - bzip2=1.0.8=h7b6447c_0
  - c-ares=1.19.1=hd590300_0
  - ca-certificates=2023.5.7=hbcca054_0
  - cachetools=4.2.4=pyhd8ed1ab_0
  - cairo=1.16.0=hb05425b_5
  - certifi=2023.5.7=pyhd8ed1ab_0
  - cffi=1.15.1=py38h4a40e3a_3
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - click=8.1.3=unix_pyhd8ed1ab_2
  - cryptography=41.0.1=py38hcdda232_0
  - cuda=11.6.1=0
  - cuda-cccl=11.6.55=hf6102b2_0
  - cuda-command-line-tools=11.6.2=0
  - cuda-compiler=11.6.2=0
  - cuda-cudart=11.6.55=he381448_0
  - cuda-cudart-dev=11.6.55=h42ad0f4_0
  - cuda-cuobjdump=11.6.124=h2eeebcb_0
  - cuda-cupti=11.6.124=h86345e5_0
  - cuda-cuxxfilt=11.6.124=hecbf4f6_0
  - cuda-driver-dev=11.6.55=0
  - cuda-gdb=12.2.53=0
  - cuda-libraries=11.6.1=0
  - cuda-libraries-dev=11.6.1=0
  - cuda-memcheck=11.8.86=0
  - cuda-nsight=12.2.53=0
  - cuda-nsight-compute=12.2.0=0
  - cuda-nvcc=11.6.124=hbba6d2d_0
  - cuda-nvdisasm=12.2.53=0
  - cuda-nvml-dev=11.6.55=haa9ef22_0
  - cuda-nvprof=12.2.60=0
  - cuda-nvprune=11.6.124=he22ec0a_0
  - cuda-nvrtc=11.6.124=h020bade_0
  - cuda-nvrtc-dev=11.6.124=h249d397_0
  - cuda-nvtx=11.6.124=h0630a44_0
  - cuda-nvvp=12.2.60=0
  - cuda-runtime=11.6.1=0
  - cuda-samples=11.6.101=h8efea70_0
  - cuda-sanitizer-api=12.2.53=0
  - cuda-toolkit=11.6.1=0
  - cuda-tools=11.6.1=0
  - cuda-visual-tools=11.6.1=0
  - cycler=0.11.0=pyhd8ed1ab_0
  - easydict=1.9=py_0
  - ffmpeg=4.3=hf484d3e_0
  - fontconfig=2.14.1=hef1e5e3_0
  - freetype=2.10.4=hca18f0e_2
  - frozenlist=1.3.3=py38h0a891b7_0
  - gds-tools=1.7.0.149=0
  - giflib=5.2.1=h5eee18b_3
  - glib=2.69.1=he621ea3_2
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - google-auth=1.35.0=pyh6c4a22f_0
  - google-auth-oauthlib=1.0.0=pyhd8ed1ab_0
  - greenlet=2.0.2=py38h17151c0_1
  - grpcio=1.56.0=py38h94a1851_2
  - icu=70.1=h27087fc_0
  - idna=3.4=pyhd8ed1ab_0
  - importlib-metadata=6.7.0=pyha770c72_0
  - intel-openmp=2023.1.0=hdb19cb5_46305
  - jpeg=9e=h5eee18b_1
  - kiwisolver=1.4.4=py38h43d8883_1
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=4.0.0=h27087fc_0
  - libabseil=20230125.3=cxx17_h59595ed_0
  - libcublas=11.9.2.110=h5e84587_0
  - libcublas-dev=11.9.2.110=h5c901ab_0
  - libcufft=10.7.1.112=hf425ae0_0
  - libcufft-dev=10.7.1.112=ha5ce4c0_0
  - libcufile=1.7.0.149=0
  - libcufile-dev=1.7.0.149=0
  - libcurand=10.3.3.53=0
  - libcurand-dev=10.3.3.53=0
  - libcusolver=11.3.4.124=h33c3c4e_0
  - libcusparse=11.7.2.124=h7538f96_0
  - libcusparse-dev=11.7.2.124=hbbe9722_0
  - libdeflate=1.17=h5eee18b_0
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=13.1.0=he5830b7_0
  - libgrpc=1.56.0=h3905398_2
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.4=h5eee18b_0
  - libnpp=11.6.3.124=hd2722f0_0
  - libnpp-dev=11.6.3.124=h3c42840_0
  - libnsl=2.0.0=h7f98852_0
  - libnvjpeg=11.6.2.124=hd473ad6_0
  - libnvjpeg-dev=11.6.2.124=hb5906b9_0
  - libpng=1.6.39=h5eee18b_0
  - libprotobuf=4.23.3=hd1fb520_0
  - libsqlite=3.42.0=h2797004_0
  - libstdcxx-ng=13.1.0=hfd8a6a1_0
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.0=h6adf6a1_2
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=2.38.1=h0b41bf4_0
  - libwebp=1.2.4=h11a3e52_1
  - libwebp-base=1.2.4=h5eee18b_1
  - libxcb=1.15=h0b41bf4_0
  - libxml2=2.9.14=h22db469_4
  - libzlib=1.2.13=hd590300_5
  - llvm-openmp=16.0.6=h4dfa4b3_0
  - lz4-c=1.9.4=h6a678d5_0
  - markdown=3.4.3=pyhd8ed1ab_0
  - markupsafe=2.1.3=py38h01eb140_0
  - matplotlib-base=3.4.3=py38hf4fb855_1
  - mkl=2023.1.0=h6d00ec8_46342
  - mkl-service=2.4.0=py38h5eee18b_1
  - mkl_fft=1.3.6=py38h417a72b_1
  - mkl_random=1.2.2=py38h417a72b_1
  - multidict=6.0.4=py38h1de0b5d_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - nsight-compute=2023.2.0.16=0
  - oauthlib=3.2.2=pyhd8ed1ab_0
  - openbabel=3.1.1=py38h3d1cf2f_4
  - openh264=2.1.1=h4ff587b_0
  - openssl=3.1.1=hd590300_1
  - pandas=2.0.3=py38h01efb38_0
  - pcre=8.45=h9c3ff4c_0
  - pillow=9.4.0=py38h6a678d5_0
  - pip=23.1.2=pyhd8ed1ab_0
  - pixman=0.40.0=h36c2ea0_0
  - protobuf=4.23.3=py38h830738e_0
  - pthread-stubs=0.4=h36c2ea0_1001
  - pyasn1=0.4.8=py_0
  - pyasn1-modules=0.2.7=py_0
  - pycairo=1.24.0=py38h1a1917b_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyjwt=2.7.0=pyhd8ed1ab_0
  - pyopenssl=23.2.0=pyhd8ed1ab_1
  - pyparsing=3.1.0=pyhd8ed1ab_0
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.8.17=he550d4f_0_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python-lmdb=1.4.1=py38h8dc9893_0
  - python-tzdata=2023.3=pyhd8ed1ab_0
  - python_abi=3.8=3_cp38
  - pytorch=1.13.1=py3.8_cuda11.6_cudnn8.3.2_0
  - pytorch-cuda=11.6=h867d48c_1
  - pytorch-mutex=1.0=cuda
  - pytz=2023.3=pyhd8ed1ab_0
  - pyu2f=0.1.5=pyhd8ed1ab_0
  - pyyaml=6.0=py38h0a891b7_5
  - rdkit=2022.03.5=py38ha829ea6_0
  - re2=2023.03.02=h8c504da_0
  - readline=8.2=h5eee18b_0
  - reportlab=3.6.12=py38h5eee18b_0
  - requests=2.31.0=pyhd8ed1ab_0
  - requests-oauthlib=1.3.1=pyhd8ed1ab_0
  - rsa=4.9=pyhd8ed1ab_0
  - setuptools=68.0.0=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - sqlalchemy=2.0.18=py38h01eb140_0
  - sqlite=3.41.2=h5eee18b_0
  - tbb=2021.8.0=hdb19cb5_0
  - tensorboard=2.13.0=pyhd8ed1ab_0
  - tensorboard-data-server=0.7.0=py38h3d167d9_0
  - tk=8.6.12=h1ccaba5_0
  - torchaudio=0.13.1=py38_cu116
  - torchvision=0.14.1=py38_cu116
  - tornado=6.3.2=py38h01eb140_0
  - typing-extensions=4.7.1=hd8ed1ab_0
  - typing_extensions=4.7.1=pyha770c72_0
  - urllib3=2.0.3=pyhd8ed1ab_1
  - werkzeug=2.3.6=pyhd8ed1ab_0
  - wheel=0.40.0=pyhd8ed1ab_0
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxdmcp=1.1.3=h7f98852_0
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7f98852_2
  - yarl=1.9.2=py38h01eb140_0
  - zipp=3.15.0=pyhd8ed1ab_0
  - zlib=1.2.13=hd590300_5
  - zstd=1.5.2=h3eb15da_6
  - pip:
    - autodocktools-py3==1.5.7.post1+7.g725082b
    - blessed==1.20.0
    - docutils==0.17.1
    - gpustat==1.1
    - jinja2==3.1.2
    - joblib==1.3.1
    - meeko==0.1.dev3
    - mmcif-pdbx==2.0.1
    - numpy==1.20.3
    - nvidia-ml-py==12.535.77
    - pdb2pqr==3.6.1
    - propka==3.5.0
    - psutil==5.9.5
    - pyg-lib==0.2.0+pt113cu116
    - scikit-learn==1.3.0
    - scipy==1.10.1
    - threadpoolctl==3.1.0
    - torch-cluster==1.6.1+pt113cu116
    - torch-geometric==2.2.0
    - torch-scatter==2.1.1+pt113cu116
    - torch-sparse==0.6.17+pt113cu116
    - torch-spline-conv==1.2.2+pt113cu116
    - tqdm==4.65.0
    - vina==1.2.2
    - wcwidth==0.2.6

A Bug with the Estimation of Pocket Size

Hello Jiaqi,

Thanks for your fantastic work! I am using targetdiff and this repo for research purposes. When reading the code for ligand sampling with the diffusion process, I came to suspect there is a bug in the pocket size estimation:

n_data = batch_size if i < num_batch - 1 else num_samples - batch_size * (num_batch - 1)
batch = Batch.from_data_list([data.clone() for _ in range(n_data)], follow_batch=FOLLOW_BATCH).to(device)

with torch.no_grad():
    
    if sample_num_atoms == 'prior':
        pocket_size = atom_num.get_space_size(batch.protein_pos.detach().cpu().numpy())
        ligand_num_atoms = [atom_num.sample_atom_num(pocket_size).astype(int) for _ in range(n_data)]

The code is copied from scripts/sample_diffusion.py (some unimportant lines are omitted). As the paper describes, you sample 100 ligand candidates for one pocket, and data here corresponds to one pocket, which is cloned n_data times to form a batch. When you estimate the pocket size, atom_num.get_space_size actually computes the pairwise distances of batch.protein_pos and returns the median of the top-10 pairwise distances. However, atoms belonging to different clones in batch are unrelated, and the pairwise distances between them do not make sense. I think the pocket_size estimation should be modified to:

pocket_size = atom_num.get_space_size(data.protein_pos.detach().cpu().numpy())

This bug also makes scripts/sample_diffusion.py and scripts/sample_for_pocket.py very slow to run, as computing pairwise distances over ~10,000 atoms in a batch is resource-hungry. Besides, the estimated pocket size is always larger than the corrected one.
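If get_space_size indeed returns the median of the ten largest pairwise distances, cloning the same pocket n_data times inflates the estimate: every distance is repeated across copies, so the top-10 collapses onto the single largest distance. A small numpy sketch of the effect (space_size_estimate below is a paraphrase for illustration, not the repo's function):

```python
import numpy as np

def space_size_estimate(pos):
    """Median of the 10 largest pairwise distances (paraphrasing atom_num.get_space_size)."""
    d = np.linalg.norm(pos[:, None, :] - pos[None, :, :], axis=-1)
    iu = np.triu_indices(len(pos), k=1)          # unique pairs only
    top10 = np.sort(d[iu])[-10:]
    return float(np.median(top10))

rng = np.random.default_rng(0)
pocket = rng.normal(size=(30, 3)) * 5.0          # one pocket's protein atoms
batched = np.tile(pocket, (20, 1))               # 20 clones, as in the batched call

single = space_size_estimate(pocket)
inflated = space_size_estimate(batched)
# In the batched case every distance appears 20*20 times, so the ten largest
# values are all copies of the overall maximum distance: inflated >= single.
```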

I am not sure if I am correct. Looking forward to your reply!

Jensen-Shannon divergence

Hello, I have a question about the Jensen-Shannon divergence. When calculating the JSD, do we compute distances between all atom pairs, or only between atoms connected by bonds?
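Whichever atom pairs are included, the JSD is ultimately computed between two normalized distance histograms. A small sketch with scipy on hypothetical bond-length samples (note scipy.spatial.distance.jensenshannon returns the square root of the divergence, so it is squared here):

```python
import numpy as np
from scipy.spatial.distance import jensenshannon

# hypothetical reference vs. generated C-C bond-length samples (Angstrom)
rng = np.random.default_rng(1)
ref = rng.normal(1.54, 0.02, 5000)
gen = rng.normal(1.50, 0.05, 5000)

# bin both samples on a common grid and compare the histograms
bins = np.linspace(1.2, 1.9, 50)
p, _ = np.histogram(ref, bins=bins, density=True)
q, _ = np.histogram(gen, bins=bins, density=True)

jsd = jensenshannon(p, q, base=2) ** 2  # square to get the divergence itself
```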

EGNN re-training

Hi @guanjq, I am not able to replicate the results by re-training the EGNN model. However, when using the pre-trained model I am able to replicate the reported results of RMSE: 1.316, MAE: 1.031, R^2 score: 0.633, Pearson: 0.797, Spearman: 0.782, mean/std: 6.412/1.621.

I tried to keep all the hyper-parameters and the dataset (/data/pdbbind_v2016/pocket_10_refined) the same by referring to the config found in the shared checkpoint, and I followed the README to prepare the pockets and splits.

My current results on the shared test set are:
RMSE: 3.082, MAE: 2.412, R^2 score: -1.014, Pearson: 0.513, Spearman: 0.562, mean/std: 7.769/3.195

Any idea what might be going wrong in re-training?

Originally posted by @Dornavineeth in #1 (comment)

Questions about pretrained models.

Thanks for the brilliant work and sharing the code!
I have a question regarding the selection of model parameters for the checkpoint ./pretrained_models/pretrained_diffusion.pt referenced in the sampling.yml file under the config directory. Could you please clarify the criteria used to choose these parameters? Is it the checkpoint with the lowest validation loss during training? I have noticed that when using that checkpoint (lowest validation loss during training), the performance does not align with the results reported in the paper. I would appreciate any insights on factors I might be overlooking in order to achieve the expected performance.
Thank you!!
Best regards.

result to mol?

Hi, thanks for the code.

result = {
    'data': data,
    'pred_ligand_pos': pred_pos,
    'pred_ligand_v': pred_v,
    'pred_ligand_pos_traj': pred_pos_traj,
    'pred_ligand_v_traj': pred_v_traj,
    'time': time_list
}

How would you get an RDKit mol from the results?
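The repo's evaluation script reconstructs RDKit molecules from the predicted positions and atom types with a reconstruction utility. As a library-free illustration (the index-to-element table below is a placeholder, not the repo's actual featurizer mapping), you can format pred_ligand_pos and pred_ligand_v as an XYZ block and then let RDKit parse it with Chem.MolFromXYZBlock and perceive bonds with rdDetermineBonds.DetermineBonds:

```python
# hypothetical index -> element table; replace with the featurizer's real mapping
ATOM_TABLE = {0: "C", 1: "N", 2: "O", 3: "F", 4: "P", 5: "S", 6: "Cl"}

def to_xyz_block(pred_pos, pred_v, table=ATOM_TABLE):
    """Format predicted positions/atom-type indices as an XYZ block string."""
    lines = [str(len(pred_pos)), "generated ligand"]
    for (x, y, z), v in zip(pred_pos, pred_v):
        lines.append(f"{table[int(v)]} {x:.4f} {y:.4f} {z:.4f}")
    return "\n".join(lines)

# usage with the result dict (lists or numpy arrays):
# xyz = to_xyz_block(result['pred_ligand_pos'], result['pred_ligand_v'])
# mol = Chem.MolFromXYZBlock(xyz)
# rdDetermineBonds.DetermineBonds(mol)   # infer connectivity from geometry
```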

the error while running sample_for_pocket

hi guanjq,

thanks for your good work on these new diffusion models.

When I try to run the command line

python scripts/sample_for_pocket.py configs/sampling.yml --pdb_path examples/1h36_A_rec_1h36_r88_lig_tt_docked_0_pocket10.pdb

the following error occurred:

[2023-08-23 23:56:23,376::evaluate::INFO] {'model': {'checkpoint': './pretrained_models/pretrained_diffusion.pt'}, 'sample': {'seed': 2021, 'num_samples': 100, 'num_steps': 1000, 'pos_only': False, 'center_pos_mode': 'protein', 'sample_num_atoms': 'prior'}}
[2023-08-23 23:56:27,030::evaluate::INFO] Training Config: {'data': {'name': 'pl', 'path': './data/crossdocked_v1.1_rmsd1.0_pocket10', 'split': './data/crossdocked_pocket10_pose_split.pt', 'transform': {'ligand_atom_mode': 'add_aromatic', 'random_rot': False}}, 'model': {'denoise_type': 'diffusion', 'model_mean_type': 'C0', 'gt_noise_type': 'origin', 'beta_schedule': 'sigmoid', 'beta_start': 1e-07, 'beta_end': 0.002, 'v_beta_schedule': 'cosine', 'v_beta_s': 0.01, 'num_diffusion_timesteps': 1000, 'loss_v_weight': 100.0, 'v_mode': 'categorical', 'v_net_type': 'mlp', 'loss_pos_type': 'mse', 'sample_time_method': 'symmetric', 'time_emb_dim': 0, 'time_emb_mode': 'simple', 'center_pos_mode': 'protein', 'node_indicator': True, 'model_type': 'uni_o2', 'num_blocks': 1, 'num_layers': 9, 'hidden_dim': 128, 'n_heads': 16, 'edge_feat_dim': 4, 'num_r_gaussian': 20, 'knn': 32, 'num_node_types': 8, 'act_fn': 'relu', 'norm': True, 'cutoff_mode': 'knn', 'ew_net_type': 'global', 'r_feat_mode': 'sparse', 'energy_h_mode': 'basic', 'num_x2h': 1, 'num_h2x': 1, 'r_max': 10.0, 'x2h_out_fc': False, 'sync_twoup': False}, 'train': {'seed': 2021, 'batch_size': 4, 'num_workers': 4, 'max_iters': 10000000, 'val_freq': 2000, 'pos_noise_std': 0.1, 'max_grad_norm': 8.0, 'bond_loss_weight': 1.0, 'optimizer': {'type': 'adam', 'lr': 0.0005, 'weight_decay': 0, 'beta1': 0.95, 'beta2': 0.999}, 'scheduler': {'type': 'plateau', 'factor': 0.6, 'patience': 10, 'min_lr': 1e-06}}}
[2023-08-23 23:56:32,218::evaluate::INFO] Successfully load the model! ./pretrained_models/pretrained_diffusion.pt
Traceback (most recent call last):
  File "scripts/sample_for_pocket.py", line 75, in <module>
    data = transform(data)
  File "c:\Users\lsy\anaconda3\envs\molgen\lib\site-packages\torch_geometric\transforms\compose.py", line 24, in __call__
    data = transform(data)
  File "d:\Cheminfo_Workshop\4_Fragment_Scaffold_Evolution\targetdiff-main\scripts\utils\transforms.py", line 128, in __call__
    amino_acid = F.one_hot(data.protein_atom_to_aa_type, num_classes=self.max_num_aa)
RuntimeError: one_hot is only applicable to index tensor.

Could you please provide suggestions on how to fix this?
Many thanks,

Best,
Sh-Y
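For what it's worth, torch.nn.functional.one_hot only accepts integer index tensors (int64), so this error usually means data.protein_atom_to_aa_type ended up with a float (or otherwise non-long) dtype during parsing. Casting it to long before the transform is a plausible fix; a minimal standalone sketch of the failure mode and the cast (not using the repo's data objects):

```python
import torch
import torch.nn.functional as F

# e.g. the amino-acid type indices were parsed as floats;
# F.one_hot(aa_type) on this tensor raises "one_hot is only applicable to index tensor"
aa_type = torch.tensor([0.0, 3.0, 5.0])

# cast to int64 first, then one-hot encode (20 standard amino acids)
amino_acid = F.one_hot(aa_type.long(), num_classes=20)
```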

Questions about data preprocessing

Hi, it's nice of you to share the code!
I have some questions about data preprocessing, specifically how to generate the file you've shared on Google Drive, 'crossdocked_v1.1_rmsd1.0_pocket10_processed_final.lmdb'.
I've gone through the code but didn't find a script that directly produces it. I notice the data preprocessing code can generate index.pkl and crossdocked_pocket10_pose_split.pt; how can I then generate the lmdb file from these files?
I'm new to this area, and it would be highly appreciated if you could help me with this. Thanks a lot!

Issue with Vina Docking During Molecule Evaluation

Firstly, I would like to express my gratitude for your impressive work and for making it available on GitHub.

I recently utilized your pre-trained model to generate molecules for a specific pocket using the following command: python scripts/sample_for_pocket.py configs/sampling.yml --pdb_path examples/1zcm.pdb. However, I encountered an issue at the Vina docking step during evaluation with scripts/evaluate_diffusion.py. Specifically, 'ligand_filename': r['data'].ligand_filename does not appear to exist in the results generated by sample_for_pocket.py.

To resolve this, I removed the aforementioned snippet from lines 110 and 143 of scripts/evaluate_diffusion.py, after which the problem seems to be resolved. I am raising this issue to bring it to your attention and to potentially help others who encounter a similar problem.

The dimension of time embedding

In configs/training.yml, time_emb_dim is set to 0. Does this value correspond to the provided pretrained_diffusion.pt?

ZeroDivisionError in evaluation

Hi,

Thank you for sharing this fantastic work. I am trying to reproduce some experimental results; I followed the instructions to download all the data and used the provided pretrained weights. Here is a screenshot of the error in the evaluation.

Could you help to view this problem and give me some hints to fix it?

By the way, the complete mols' JS bond distances also show None, which I don't think is correct.


Many thanks!

Reopening issue of not having pocket10 files for training

I am reopening issue #3, since the data folder on GitHub has no folder called crossdocked_v1.1_rmsd1.0_pocket10. What is supposed to be in this folder? Should it have the same folder structure as the crossdocked_v1.1_rmsd1.0 folder from the GitHub link, which you get after untarring?

crossdocked_v1.1_rmsd1.0
├── 1433B_HUMAN_1_240_pep_0
│   ├── 5f74_A_rec_5f74_amp_lig_tt_docked_5.sdf
│   ├── 5f74_A_rec.pdb
│   ├── 5n10_A_rec_5f74_amp_lig_tt_min_0.sdf
│   └── 5n10_A_rec.pdb
├── 1433C_TOBAC_1_256_0
│   ├── 2o98_B_rec_2o98_fsc_lig_tt_docked_0.sdf
│   └── 2o98_B_rec.pdb
├── 1433S_HUMAN_1_233_0
│   ├── 3iqu_A_rec_3p1s_fsc_lig_tt_docked_10.sdf
│   ├── 3iqu_A_rec.pdb
│   ├── 3iqv_A_rec_3iqv_fsc_lig_tt_docked_0.sdf
│   ├── 3iqv_A_rec_3p1o_fsc_lig_tt_docked_0.sdf
│   ├── 3iqv_A_rec_3p1q_fsc_lig_tt_docked_0.sdf
│   ├── 3iqv_A_rec_3p1q_fsc_lig_tt_docked_5.sdf
│   ├── 3iqv_A_rec_3p1s_fsc_lig_tt_docked_0.sdf
│   ├── 3iqv_A_rec_3smk_cw7_lig_tt_docked_1.sdf
│   ├── 3iqv_A_rec_3smm_fja_lig_tt_docked_0.sdf
│   ├── 3iqv_A_rec_3smo_fja_lig_tt_docked_2.sdf
│   ├── 3iqv_A_rec_3smo_fja_lig_tt_min_0.sdf
│   ├── 3iqv_A_rec_3sp5_cx7_lig_tt_docked_0.sdf
│   ├── 3iqv_A_rec_4dhn_0kc_lig_tt_min_0.sdf
│   ├── 3iqv_A_rec_4fr3_0v4_lig_tt_docked_2.sdf
│   ├── 3iqv_A_rec_4jdd_fsc_lig_tt_docked_1.sdf
│   ├── 3iqv_A_rec_5mxo_fsc_lig_tt_docked_0.sdf
│   ├── 3iqv_A_rec.pdb

how to get the batch.ligand_element_batch

Hi, thank you for sharing such good work. However, I am a little confused about how to obtain batch.ligand_element_batch in:

def train(it):
    model.train()
    optimizer.zero_grad()
    for _ in range(config.train.n_acc_batch):
        batch = next(train_iterator).to(args.device)

        results = model.get_diffusion_loss(
            ligand_pos=batch.ligand_pos,
            ligand_v=batch.ligand_atom_feature_full,
            batch_ligand=batch.ligand_element_batch
        )

Can you tell me where I can find the processing code that creates the ligand element batch?
Thank you
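In PyTorch Geometric, attributes listed in follow_batch when calling Batch.from_data_list (the FOLLOW_BATCH constant used by the sampling/training scripts) get a companion <attr>_batch vector that maps each row of that attribute to its graph index; that is where ligand_element_batch comes from. Conceptually it is just graph indices repeated by per-graph sizes, as this numpy sketch shows:

```python
import numpy as np

# hypothetical: three ligands with 4, 2, and 3 atoms collated into one batch
ligand_sizes = [4, 2, 3]

# ligand_element_batch assigns each atom row to the index of its graph,
# exactly what PyG builds for attributes named in follow_batch
ligand_element_batch = np.repeat(np.arange(len(ligand_sizes)), ligand_sizes)
# -> [0 0 0 0 1 1 2 2 2]
```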

settings in training.yml

Great work!
@guanjq a couple of questions:

  1. In the provided training.yml, ligand_atom_mode is "add_aromatic" and random_rot is "False". Are these the settings of the provided pretrained model?
  2. Have you tried "full" and "True", respectively? And would "full" and "True" lead to better results?

#############
data:
  ...
  transform:
    ligand_atom_mode: add_aromatic
    random_rot: False
##############
