MiDi: Mixed Graph and 3D Denoising Diffusion for Molecule Generation

Clément Vignac*, Nagham Osman*, Laura Toni, Pascal Frossard

ECML 2023

Installation

This code was tested with PyTorch 2.0.1, CUDA 11.8 and torch_geometric 2.3.1 on multiple GPUs.

  • Download anaconda/miniconda if needed

  • Create a conda environment that already includes rdkit:

    conda create -c conda-forge -n midi rdkit=2023.03.2 python=3.9

  • Activate the environment:

    conda activate midi

  • Check that this line does not return an error:

    python3 -c 'from rdkit import Chem'

  • Install the nvcc drivers for your cuda version. For example:

    conda install -c "nvidia/label/cuda-11.8.0" cuda

  • Install a corresponding version of pytorch, for example:

    pip3 install torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118

  • Install other packages using the requirement file:

    pip install -r requirements.txt

  • Run:

    pip install -e .
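As an optional sanity check after the steps above, the following sketch reports which of the main dependencies can be imported and which versions are found. The helper name `report_versions` is hypothetical and not part of MiDi; it only assumes that each package exposes a `__version__` attribute.

```python
import importlib
import importlib.util


def report_versions(packages):
    """Map each top-level package name to its version string, or None if not installed."""
    versions = {}
    for name in packages:
        if importlib.util.find_spec(name) is None:
            versions[name] = None  # the package cannot be found in this environment
            continue
        module = importlib.import_module(name)
        # Fall back to "unknown" when the package does not expose __version__.
        versions[name] = getattr(module, "__version__", "unknown")
    return versions


if __name__ == "__main__":
    found = report_versions(["rdkit", "torch", "torch_geometric"])
    for name, version in found.items():
        print(f"{name}: {version if version is not None else 'NOT INSTALLED'}")
```

If the printed versions differ from the ones listed at the top of this section, the code may still run, but that configuration has not been tested.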

Datasets

Training:

First move inside the src folder (so that the outputs are saved at the right location):

Some examples:

QM9 without hydrogens on CPU

python3 main.py dataset=qm9 dataset.remove_h=True +experiment=qm9_no_h

GEOM-DRUGS with hydrogens on 2 GPUs

python3 main.py dataset=geom dataset.remove_h=False +experiment=geom_with_h general.gpus=2

Resuming a previous run

First, retrieve the absolute path of the checkpoint. It looks like ABS_PATH='/home/vignac/MiDi/outputs/2023-02-13/18-10-49-geomH/checkpoints/geomH_bigger/epoch=219.ckpt'

Then run:

python3 main.py dataset=qm9 dataset.remove_h=True +experiment=qm9_no_h general.resume='ABS_PATH'
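If you do not remember where the checkpoint was written, a small script can locate it. This is a minimal sketch, not part of MiDi: the helper `latest_checkpoint` is hypothetical, and it assumes that checkpoints are `.ckpt` files stored somewhere under an `outputs` directory, as in the example path above.

```python
from pathlib import Path


def latest_checkpoint(outputs_dir):
    """Return the absolute path of the most recently modified .ckpt file, or None."""
    root = Path(outputs_dir)
    if not root.is_dir():
        return None
    checkpoints = list(root.rglob("*.ckpt"))  # search all run subfolders
    if not checkpoints:
        return None
    newest = max(checkpoints, key=lambda p: p.stat().st_mtime)
    return newest.resolve()


if __name__ == "__main__":
    # "outputs" is an assumed location relative to the current directory.
    print(latest_checkpoint("outputs"))
```

The most recently modified file is not necessarily the best epoch, so check the printed path before passing it to general.resume.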

Evaluation

Sampling on multiple GPUs is not fully supported; we recommend sampling on a single GPU.

Run:

python3 main.py dataset=qm9 dataset.remove_h=True +experiment=qm9_no_h general.test_only='ABS_PATH'
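The training, resuming and evaluation commands differ only in a few overrides, so they can be assembled programmatically when launching many runs. The sketch below is one possible way to do this; `build_command` is a hypothetical helper, and it uses only the overrides that appear in the commands of this README.

```python
import shlex


def build_command(dataset, remove_h, experiment, gpus=None, resume=None, test_only=None):
    """Assemble the argument list for main.py from the overrides used in this README."""
    if resume is not None and test_only is not None:
        raise ValueError("Pass either resume or test_only, not both.")
    args = [
        "python3", "main.py",
        f"dataset={dataset}",
        f"dataset.remove_h={remove_h}",
        f"+experiment={experiment}",
    ]
    if gpus is not None:
        args.append(f"general.gpus={gpus}")
    if resume is not None:
        args.append(f"general.resume={resume}")  # absolute checkpoint path
    if test_only is not None:
        args.append(f"general.test_only={test_only}")  # absolute checkpoint path
    return args


if __name__ == "__main__":
    # Print a shell-ready version of the GEOM-DRUGS training command.
    print(shlex.join(build_command("geom", False, "geom_with_h", gpus=2)))
```

The returned list can be passed to `subprocess.run` from inside the src folder, or printed and pasted into a shell as shown.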

Checkpoints

QM9 implicit H:

  • command: python3 main.py dataset=qm9 dataset.remove_h=True +experiment=qm9_with_h_uniform
  • checkpoint: missing

QM9 explicit H:

Geom implicit H:

Geom explicit H:

Generated samples

QM9 implicit H:

QM9 explicit H:

Geom with explicit H:

Evaluate your model on the proposed metrics

To benchmark your own model with the proposed metrics, you can use the sampling_metrics function in src/metrics/molecular_metrics.py: sampling_metrics(molecules=molecule_list, name='my_method', current_epoch=-1, local_rank=0).

You'll need to write a few lines to load your generated graphs and create a list of Molecule objects (in src/analysis/rdkit_functions.py).

Use MiDi on a new dataset

To implement a new dataset, you will need to create a new file in the src/datasets folder. This file should implement a Dataset class, a Datamodule class and an Infos class. Check qm9_dataset.py and geom_dataset.py for examples.

Once the dataset file is written, the code in main.py can be adapted to handle the new dataset, and a new file can be added in configs/dataset.

Use OpenBabel for baseline results

  • In this work, we use Open Babel GUI for bond prediction.
  • Install the Open Babel version that matches your machine. You can download it using the following link.
  • For the input format, you need to choose "xyz -- XYZ cartesian coordinates format".
  • For the output format, you need to choose "sdf -- MDL MOL format".
  • In the additional instructions window, write the word "end" in the section "Add or replace molecule title".
  • In the input section, select all the xyz files for which you want to predict bonds.
  • Choose the directory where you want to save the output file, then click on Convert.
  • You can then use the function open_babel_eval in midi/analysis/baselines_evaluation which requires the path as argument.
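For large batches of files, the GUI steps above could in principle be scripted with the `obabel` command-line tool that ships with Open Babel. The following is a sketch under assumptions: the baseline results in this work were produced with the GUI, and it has not been verified here that the command line gives identical output. The helper `build_obabel_command` and the folder and file names are hypothetical.

```python
import shutil
import subprocess
from pathlib import Path


def build_obabel_command(xyz_files, output_sdf):
    """Build an obabel call that converts XYZ inputs into one SDF file titled 'end'."""
    return [
        "obabel",
        "-ixyz", *[str(f) for f in xyz_files],  # input format and input files
        "-osdf", "-O", str(output_sdf),         # output format and output file
        "--title", "end",                       # add or replace the molecule title
    ]


if __name__ == "__main__":
    files = sorted(Path("xyz_inputs").glob("*.xyz"))  # hypothetical input folder
    command = build_obabel_command(files, "predicted_bonds.sdf")
    if files and shutil.which("obabel"):
        subprocess.run(command, check=True)
    else:
        # Nothing to convert, or obabel is not on the PATH: only show the command.
        print("Would run:", " ".join(command))
```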

Cite this paper

@article{vignac2023midi,
  title={MiDi: Mixed Graph and 3D Denoising Diffusion for Molecule Generation},
  author={Vignac, Clement and Osman, Nagham and Toni, Laura and Frossard, Pascal},
  journal={arXiv preprint arXiv:2302.09048},
  year={2023}
}
