
Comments (5)

ShayneWierbowski commented on September 10, 2024

I'm just playing around with this repository myself and was also interested in this question (so I'm not extensively familiar with the code). It'd certainly be better to have an official response here, but from what I can tell, you should just need to modify the reconstruct method inside the HierVAE model.

    def reconstruct(self, batch):
        # I think these should match the output from the preprocess.py script
        graphs, tensors, _ = batch
        
        # Reformat as tensors (from numpy arrays?)
        tree_tensors, graph_tensors = tensors = make_cuda(tensors)
        
        # Encode batch of compounds
        root_vecs, tree_vecs, _, graph_vecs = self.encoder(tree_tensors, graph_tensors)
        
        # Modify the root_vecs embeddings? (Not actually sure what's happening between these two steps)
        # But the output root_vecs here should be the latent embeddings as far as I can tell (so if you just return
        # these instead of running the next step, this should be what you want)
        root_vecs, root_kl = self.rsample(root_vecs, self.R_mean, self.R_var, perturb=False)
        
        # Convert the latent embedding(s) back into SMILES string(s)
        return self.decoder.decode((root_vecs, root_vecs, root_vecs), greedy=True, max_decode_step=150)

If you aren't sure how to get everything in the right format for the input to the reconstruct method, I wrote up a test case that reads in a model from the checkpoint, takes a given SMILES string, encodes it, and then decodes it.

Again, I'm not 100% confident this is accurate; my attempted encode-decode process did not reproduce the same compound, so it's possible I'm misunderstanding something here.

# Imports / Arg Parser / Functions
# Copied from generate.py, preprocess.py, and/or hgraph/hgnn.py

from multiprocessing import Pool
import math, random, sys
import pickle
import argparse
from functools import partial
import torch
import numpy

from hgraph import *
import rdkit

def make_cuda(tensors):
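    # (Note: despite the name, this no-CUDA variant only converts numpy arrays
    # to torch tensors; if I remember right, the upstream make_cuda also calls
    # .cuda() on each tensor, which is dropped here for CPU use.)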
    tree_tensors, graph_tensors = tensors
    make_tensor = lambda x: x if type(x) is torch.Tensor else torch.tensor(x)
    tree_tensors = [make_tensor(x).long() for x in tree_tensors[:-1]] + [tree_tensors[-1]]
    graph_tensors = [make_tensor(x).long() for x in graph_tensors[:-1]] + [graph_tensors[-1]]
    return tree_tensors, graph_tensors

def to_numpy(tensors):
    convert = lambda x : x.numpy() if type(x) is torch.Tensor else x
    a,b,c = tensors
    b = [convert(x) for x in b[0]], [convert(x) for x in b[1]]
    return a, b, c

def tensorize(mol_batch, vocab):
    x = MolGraph.tensorize(mol_batch, vocab, common_atom_vocab)
    return to_numpy(x)

parser = argparse.ArgumentParser()
parser.add_argument('--vocab', required=True)
parser.add_argument('--atom_vocab', default=common_atom_vocab)
parser.add_argument('--model', required=True)

parser.add_argument('--seed', type=int, default=7)
parser.add_argument('--nsample', type=int, default=10000)

parser.add_argument('--rnn_type', type=str, default='LSTM')
parser.add_argument('--hidden_size', type=int, default=250)
parser.add_argument('--embed_size', type=int, default=250)
parser.add_argument('--batch_size', type=int, default=50)
parser.add_argument('--latent_size', type=int, default=32)
parser.add_argument('--depthT', type=int, default=15)
parser.add_argument('--depthG', type=int, default=15)
parser.add_argument('--diterT', type=int, default=1)
parser.add_argument('--diterG', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0.0)

args = parser.parse_args()

# Parse Vocabuary File
vocab = [x.strip("\r\n ").split() for x in open(args.vocab)]
args.vocab = PairVocab(vocab)

# Test Compound to reconstruct
smiles = ['C=Cc1cccc(C(=O)N2CC(c3ccc(F)cc3)C(C)(C)C2)c1']

print("\nINPUT SMILES: {0}\n".format(" ".join(smiles)))

# Convert SMILES String into MolGraph Tree / Graph Tensors
# (See preprocess.py)
o = tensorize(smiles, args.vocab)
batches, tensors, all_orders = o

# Extract pieces we need
tree_tensors, graph_tensors = make_cuda(tensors)


# Load Checkpoint model
model = HierVAE(args)

model.load_state_dict(torch.load(args.model, map_location=torch.device('cpu'))[0])
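# (The [0] index is because the training script appears to save the checkpoint
# as a tuple with the model state dict as its first element.)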
model.eval()


# Encode compound
root_vecs, tree_vecs, _, graph_vecs = model.encoder(tree_tensors, graph_tensors)

print("\nLATENT_EMBEDDING\n")
print(root_vecs)
print("\n")

# Unsure what this second step does or what the difference between
# the first and second root_vecs values is (see the note after the call below)
root_vecs, root_kl = model.rsample(root_vecs, model.R_mean, model.R_var, perturb=False)
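# (If I'm reading hgraph/hgnn.py right, rsample is the standard VAE
# reparameterization step: R_mean / R_var are linear layers mapping the encoder
# output to the posterior mean and log-variance, and with perturb=False no
# noise is added, so it simply returns the mean. That would make the first
# root_vecs the raw encoder output and the second the actual latent code.)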

print("\nLATENT_EMBEDDING_2\n")
print(root_vecs)
print("\n")


# Decode compound
decoded_smiles = model.decoder.decode((root_vecs, root_vecs, root_vecs), greedy=True, max_decode_step=150)
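# (The decoder takes a triple of vectors; from what I can tell, sampling in
# generate.py likewise passes the same root vector for all three slots.)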


# The decoded and original smiles / compound do not match
# Not sure if this is because something is done wrong or just
# because this compound is one that couldn't be reconstructed
# accurately
print("DECODED SMILES: {0}".format("".join(decoded_smiles)))


ShayneWierbowski commented on September 10, 2024

It's possible that the mismatch between the original and decoded SMILES strings is related to the underlying rdkit implementation issues referenced in other issues (#20).

When I expand this to test multiple compounds, I start getting two errors that I believe are linked to this:

Traceback (most recent call last):
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/test_smiles2tensor.py", line 70, in <module>
    o = tensorize(smiles, args.vocab)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/test_smiles2tensor.py", line 29, in tensorize
    x = MolGraph.tensorize(mol_batch, vocab, common_atom_vocab)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 152, in tensorize
    mol_batch = [MolGraph(x) for x in mol_batch]
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 152, in <listcomp>
    mol_batch = [MolGraph(x) for x in mol_batch]
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 22, in __init__
    self.order = self.label_tree()
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 123, in label_tree
    tree.nodes[i]['assm_cands'] = get_assm_cands(mol, hist, inter_label, pa_cls, len(inter_atoms))
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/chemutils.py", line 120, in get_assm_cands
    mol = get_clique_mol(mol, atoms)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/chemutils.py", line 111, in get_clique_mol
    smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
rdkit.Chem.rdchem.KekulizeException: Can't kekulize mol.  Unkekulized atoms: 22

and

Traceback (most recent call last):
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/test_smiles2tensor.py", line 70, in <module>
    o = tensorize(smiles, args.vocab)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/test_smiles2tensor.py", line 29, in tensorize
    x = MolGraph.tensorize(mol_batch, vocab, common_atom_vocab)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 153, in tensorize
    tree_tensors, tree_batchG = MolGraph.tensorize_graph([x.mol_tree for x in mol_batch], vocab)
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/mol_graph.py", line 194, in tensorize_graph
    fnode[v] = vocab[attr]
  File "/mnt/data-rs-wihy299nas/sdw95/Drug_Project/Molecule_Encoder_Decoder/hgraph2graph_nocuda/hgraph/vocab.py", line 43, in __getitem__
    return self.hmap[x[0]], self.vmap[x]
KeyError: 'C1=CNN=C1'

Will try to get the package versions corrected and confirm if this is the problem.


michal-pikusa commented on September 10, 2024

Thank you @ShayneWierbowski! Your code actually works pretty nicely, I think.

I've been able to encode a set of molecules I work with (around 4,000 of them) and reconstruct them back with your code; for my set, ~20% are reconstructed correctly 1:1. The rest are not 1:1, but the median Tanimoto similarity is ~0.8, meaning most molecules are still very similar to the original with only minor modifications. This can probably be improved with the training hyperparameters, as I saw a similar thing with their previous work (JT-VAE).
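
For reference, that kind of check is easy to reproduce with RDKit Morgan fingerprints; a minimal sketch (the radius / bit settings here are just common defaults, not anything prescribed by this repo):

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

def tanimoto(smi_a, smi_b):
    # Morgan fingerprints (radius 2, 2048 bits) for each molecule
    fp_a, fp_b = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(s), 2, nBits=2048)
                  for s in (smi_a, smi_b)]
    return DataStructs.TanimotoSimilarity(fp_a, fp_b)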

As to your error, make sure you are trying to encode molecules that are part of your training set, or whose motifs are all covered by the vocabulary. You cannot encode something that is out of vocabulary, and I think that's why you are getting the KeyError. I haven't gotten a single one while encoding all the molecules from my training set.
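
If you want to screen a list of candidates up front, one quick (if blunt) way is to just attempt tensorization and skip anything that fails (a sketch with a hypothetical candidate_smiles list, reusing tensorize and args.vocab from the script above):

encodable = []
for smi in candidate_smiles:
    try:
        # MolGraph.tensorize raises a KeyError when a motif is missing from the vocab
        tensorize([smi], args.vocab)
        encodable.append(smi)
    except KeyError as e:
        print("Skipping {0}: motif {1} not in vocab".format(smi, e))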

Thanks again for your help.


ShayneWierbowski commented on September 10, 2024

@michal-pikusa I'm glad you found this helpful!

Thanks for your suggestion about the vocabulary as well. This is definitely an important consideration. In my case, I think it was linked to the RDKit version, and I was able to get everything running smoothly after reinstalling the correct version.

From my expanded evaluation, I came up with similar results to yours (~20% perfectly reconstructed, the rest with generally high Tanimoto similarity).

From your experience retraining the model and/or tweaking hyperparameters for a different drug/vocabulary set, do you have a sense of how long the training takes? I haven't played with that yet.

from hgraph2graph.

michal-pikusa commented on September 10, 2024

@ShayneWierbowski:
Training on my 4k set took ~30 minutes on a single GPU with 16 GB of memory, so it's really fast.
