
discrete_dppm_graphs's Introduction

Diffusion Models for Graphs Benefit From Discrete State Spaces

This is the official repository for the paper "Diffusion Models for Graphs Benefit From Discrete State Spaces". Link to the paper: https://arxiv.org/abs/2210.01549

Installation

Use the package manager Anaconda or Miniconda to create an environment with the required dependencies. You may use the following commands to do so:

conda env create -f environment.yml
conda activate DISCDDPM

Your conda environment should now be active and you are ready to run the project.

Running the code

As evaluated in the paper, there are three different implementations of this project. Each implementation has a training file and a sampling file. Additionally, to recreate the tests presented in the paper, we include the gridsearch files we used to run the different seeds and datasets as Slurm jobs.

Training directly as a Python Job

If you do not want to run a gridsearch but only want to use the model to generate graphs, you may run the commands generated by the gridsearch scripts directly instead of writing them to a Slurm .sh file. Fully working config files are provided in the /config directory.

All results of training and the subsequent sampling (models, configs, output files, etc.) are stored in a newly generated directory "./exp/modelname_dataname_month_time_randomnumber/...".
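As a rough illustration of the naming scheme above (the helper and its exact format are hypothetical; the real directory is created by the training scripts), such a run directory could be assembled like this:

```python
import random
import time

def make_run_dir(model_name: str, data_name: str) -> str:
    """Sketch of a run-directory name of the form
    ./exp/modelname_dataname_month_time_randomnumber (hypothetical helper)."""
    stamp = time.strftime("%b_%H%M%S")   # month abbreviation plus time of day
    rand = random.randint(0, 9999)       # random suffix to avoid collisions
    return f"./exp/{model_name}_{data_name}_{stamp}_{rand}"

print(make_run_dir("ppgn", "ego_small"))
```

The random suffix keeps concurrent runs with identical settings from overwriting each other's outputs.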

Sampling

If you wish to sample from an already trained model, you can do so manually with the Jupyter notebook sample.ipynb. The notebook also lets you print the sampled results over several noise levels, which gives some insight into what number of sampling steps is reasonable.

Running as a gridsearch Slurm job to reproduce results

For the Slurm service to work, please configure the file scripts/gridsearch.sh to your Slurm specifications. Depending on which of the three implementations you wish to run, open the corresponding gridsearch_... file and edit your hyperparameters in the first paragraph as described in the comments. Additionally, you must change the line marked with "ATTENTION" to match your Slurm service command.
This Python script generates the needed config files in the directory "./config/gridsearch/consec_modelname_datetime/...", based on the parameters in the template files (config/edp_final.yaml & config/ppgn_final.yaml) and the values specified in the gridsearch file.
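The template-plus-sweep expansion described above can be sketched as follows (a minimal illustration with made-up keys, not the repository's actual gridsearch code, and plain dicts instead of the YAML templates):

```python
import copy
import itertools

# Hypothetical template config and hyperparameter grid; the real values
# live in config/edp_final.yaml / config/ppgn_final.yaml and the
# gridsearch_... scripts.
template = {"dataset": "ego_small", "lr": 1e-4, "seed": 0}
grid = {"seed": [0, 1, 2], "lr": [1e-4, 1e-3]}

# For every combination of swept values, copy the template and override
# the swept keys, yielding one config per job.
configs = []
for values in itertools.product(*grid.values()):
    cfg = copy.deepcopy(template)
    cfg.update(dict(zip(grid.keys(), values)))
    configs.append(cfg)

print(len(configs))  # 3 seeds x 2 learning rates = 6 configs
```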

Note:

We used wandb to track our results. If you want to use it, make sure you add your wandb key and username in the "train_main" function of the training files. If you do not wish to use wandb, simply remove or comment out all calls to wandb.log() in the training files.
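An alternative to deleting the wandb.log() calls is to guard them, so training runs whether or not wandb is installed and configured. A minimal sketch (the log_metrics helper is hypothetical, not part of the repository):

```python
# Make wandb optional: import it if available, otherwise fall back to stdout.
try:
    import wandb  # only used when installed and a run has been initialized
except ImportError:
    wandb = None

def log_metrics(metrics: dict) -> None:
    """Log to wandb when a run is active, otherwise print to stdout."""
    if wandb is not None and wandb.run is not None:
        wandb.log(metrics)
    else:
        print(metrics)

log_metrics({"loss": 0.42})
```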

discrete_dppm_graphs's People

Contributors

kilianhae


discrete_dppm_graphs's Issues

Question about running the code directly

Hi, thanks for sharing the code. I tried to run the code directly without Slurm, with a command like "python ppgn_simple.py -c ./config/train_ego_small_ddpm_16.yaml". However, I found that some args used in ppgn_simple.py are missing from the config file, for example config.noiselevels. Could you please provide some guidance on fixing this? I don't know whether the remaining setup is done in the gridsearch-related files. Thanks in advance.

Question about the two components of loss_func_kld in ppgn_vlb.py

Hi!
Applying D3PM to the graph domain is really brilliant work, and the derivations in your paper are clear and easy to follow. When reading your implementation, though, the two components of the KL divergence loss confuse me, and I just cannot figure them out.

In ppgn_vlb.py:
I understand that the q in line 93 is actually the implementation of Eq. (2), q(A_{t-1}|A_t, A_0), and that the q in line 93 together with grad_inv in line 116 represents q when A_{t-1} = 1 or 0, respectively. (Please let me know if my understanding here is wrong.)

**My question is: what is the posterior p from line 95 to line 102, and the one from line 105 to line 109, separately?**
Why do you sum p with the weights score and (1 - score) in line 109?
In my understanding, the predicted posterior p is the output of your model and has nothing to do with the q-like computation (mult1 * mult2 / div), so why is the predicted posterior p computed in the same way?

Looking forward to your reply.
James
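For context on the question above, the posterior q(A_{t-1}|A_t, A_0) for a single binary edge under a symmetric flip schedule can be written out generically. The sketch below is a textbook D3PM computation with assumed names (beta_t as the per-step flip probability, alpha_bar_prev as the cumulative keep probability up to step t-1), not the repository's actual code:

```python
def binary_posterior(x_t: int, x_0: int, beta_t: float, alpha_bar_prev: float):
    """Generic D3PM posterior q(x_{t-1} | x_t, x_0) for one binary state
    with a symmetric flip probability beta_t (illustrative sketch only)."""
    def q_step(xt: int, xprev: int) -> float:
        # Forward step q(x_t | x_{t-1}): keep with prob 1 - beta_t, flip with beta_t.
        return 1.0 - beta_t if xt == xprev else beta_t

    def q_marginal(xprev: int) -> float:
        # Marginal q(x_{t-1} | x_0) under the cumulative uniform schedule.
        keep = alpha_bar_prev + (1.0 - alpha_bar_prev) / 2.0
        return keep if xprev == x_0 else 1.0 - keep

    # Bayes: q(x_{t-1} | x_t, x_0) ∝ q(x_t | x_{t-1}) * q(x_{t-1} | x_0).
    unnorm = [q_step(x_t, k) * q_marginal(k) for k in (0, 1)]
    z = sum(unnorm)
    return [u / z for u in unnorm]  # [P(x_{t-1}=0 | .), P(x_{t-1}=1 | .)]

print(binary_posterior(x_t=1, x_0=1, beta_t=0.1, alpha_bar_prev=0.8))
```

Evaluating this for both values of A_{t-1} matches the two-branch structure (q and its complement) that the question refers to.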

Question on ppgn_simple BCE loss implementation

Hi,

I really appreciate your sharing the code for this fantastic work. However, I am confused by this line in the ppgn_simple.py file. Based on your paper, when the model is trained with the simple loss target, it should calculate the cross-entropy loss between A0 and nnθ(At). However, the code calculates the cross-entropy loss between grad_log_noise and nnθ(At). Why is that?

Looking forward to your reply.

Hang
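For reference, the cross-entropy target the question describes (between an entry of A0 and the model's predicted edge probability) reduces per edge to a standard binary cross entropy. A pure-Python sketch with hypothetical names, not the repository's implementation:

```python
import math

def bce(pred: float, target: float, eps: float = 1e-12) -> float:
    """Binary cross entropy between a predicted edge probability and a
    0/1 target entry of A0 (illustrative sketch only)."""
    pred = min(max(pred, eps), 1.0 - eps)  # clamp for numerical stability
    return -(target * math.log(pred) + (1.0 - target) * math.log(1.0 - pred))

print(bce(0.5, 1.0))  # ln 2, about 0.6931
```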

Question about modifying the number of the PPGN model's diffusion steps 😵

This is a great piece of work! (๑•̀ㅂ•́)و✧
My question:
How do I modify the number of the PPGN model's diffusion steps (which you mention in the paper is 32-64), and the number of diffusion steps used for PPGN sampling?
Just tell me which line in the .yml or .py file to modify.
Thank you very much for your patience and help! ♥
