Exponential-SNR Diffusion

This repository contains the implementation of Exponential-SNR Diffusion. This model aims to enhance the signal-to-noise ratio (SNR) exponentially over time to improve the performance of diffusion models.
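For orientation, the sketch below computes the per-timestep SNR of a standard DDPM forward process under a linear noise schedule, which is the quantity the description above refers to. It is a minimal illustration, not the repository's implementation: the helper names `linear_beta_schedule` and `snr_per_timestep` are hypothetical, the beta range (1e-4 to 0.02 at 1000 steps) is the common DDPM default and is assumed rather than taken from this codebase, and the exponential SNR treatment itself is not shown.

```python
def linear_beta_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced betas, rescaled so the range matches the 1000-step default.

    Assumption: this mirrors the common DDPM "linear" schedule; the
    repository may define its schedule differently.
    """
    scale = 1000 / num_steps
    start, end = scale * beta_start, scale * beta_end
    if num_steps == 1:
        return [start]
    step = (end - start) / (num_steps - 1)
    return [start + i * step for i in range(num_steps)]


def snr_per_timestep(betas):
    """Return SNR(t) = alpha_bar_t / (1 - alpha_bar_t) for each timestep.

    alpha_bar_t is the cumulative product of (1 - beta_s) for s <= t.
    """
    snrs = []
    alpha_bar = 1.0
    for beta in betas:
        alpha_bar *= 1.0 - beta
        snrs.append(alpha_bar / (1.0 - alpha_bar))
    return snrs


if __name__ == "__main__":
    snr = snr_per_timestep(linear_beta_schedule(1000))
    # Print a few values to see how quickly the SNR falls as t grows.
    for t in (0, 250, 500, 750, 999):
        print(f"t={t:4d}  SNR={snr[t]:.6g}")
```

Under these assumptions the SNR falls monotonically as the timestep increases, spanning several orders of magnitude between the first and last step.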

Setup Instructions

To set up the environment for this project, follow these steps:

  1. Create a Conda Environment:

    conda create -n diffusion python=3.8
  2. Activate the Conda Environment:

    conda activate diffusion

Repository Structure

  • improved_diffusion/: Contains the source code for the Exponential-SNR Diffusion model.
  • datasets/: Includes scripts to prepare the CIFAR-10 dataset.
  • notebooks/: Jupyter notebooks for experimentation and visualization.
  • scripts/: Shell and Python scripts for various utilities and tasks.

Trained checkpoints and sampled images are also available for download from this address.

Getting Started

  1. Clone the Repository:

    git clone https://github.com/HoangAnAIRS/exponential-SNR-diffusion.git
    cd exponential-SNR-diffusion
  2. Install Dependencies:

    Ensure you have activated the diffusion environment:

    conda activate diffusion

    Then, install the required packages:

    pip install -e . 
    pip install -r requirements.txt
  3. Training the Model:

    To train the model, use the following script:

    CUDA_VISIBLE_DEVICES=1 python3 image_train.py \
    --data_dir /home/admin/workspace/user/improved-diffusion/datasets/cifar_train \
    --image_size 32 \
    --num_channels 128 \
    --num_res_blocks 3 \
    --learn_sigma False \
    --dropout 0.3 \
    --diffusion_steps 1000 \
    --noise_schedule linear \
    --lr 1e-4 \
    --batch_size 128 \
    --schedule_sampler "early" \
    --log_dir /home/admin/workspace/user/improved-diffusion/logs_early_snr_clamp_5.0

    Alternatively, go to the scripts folder, edit the configuration in train.sh, and run:

    bash train.sh
  4. Sampling from the Model:

    To generate samples from the trained model, use the provided script:

    python3 image_sample.py \
    --image_size 32 \
    --num_channels 128 \
    --num_res_blocks 3 \
    --dropout 0.3 \
    --diffusion_steps 1000 \
    --noise_schedule linear \
    --timestep_respacing ddim250 \
    --use_ddim True \
    --device cuda:2 \
    --model_path /home/admin/workspace/user/improved-diffusion/logs_early_snr_x_start/ema_0.9999_080000.pt \
    --output_npz_dir /home/admin/workspace/user/improved-diffusion/sampled_images_early_snr_x_start_ddim250

    Alternatively, go to the scripts folder, edit the configuration in sample.sh, and run:

    bash sample.sh
  5. Batch Sampling with Varying Models:

    To perform batch sampling using different models, use the following script:

    ./batch_sample_varies_model.sh

    This script loops through each model and generates samples, storing the results in the specified output directory.

  6. Batch Sampling with Varying DDIM:

    To perform batch sampling with varying DDIM timesteps, use the following script:

    ./batch_sample_varies_ddim.sh

    This script loops through different timestep respacing values and generates samples for each, storing the results in the respective output directories.
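The loop over timestep respacing values can be sketched in Python as follows. This is an illustration under stated assumptions, not the contents of batch_sample_varies_ddim.sh: the helpers `ddim_timesteps` and `build_sample_command`, the checkpoint path, and the output directory naming are hypothetical. The stride search reflects how a `ddimN` respacing string is commonly resolved in the upstream improved-diffusion codebase (pick the integer stride that yields exactly N evenly spaced steps); this repository may differ. Only flags that appear in the sampling command above are used.

```python
def ddim_timesteps(num_timesteps, desired_count):
    """Return the evenly strided timesteps kept for a 'ddim<N>' respacing.

    Searches for an integer stride such that range(0, num_timesteps, stride)
    has exactly desired_count entries, and fails if no such stride exists.
    """
    for stride in range(1, num_timesteps):
        steps = range(0, num_timesteps, stride)
        if len(steps) == desired_count:
            return list(steps)
    raise ValueError(
        f"cannot create exactly {desired_count} steps with an integer stride"
    )


def build_sample_command(respacing, model_path, output_root):
    """Build the argument list for one sampling run (nothing is executed).

    Assumption: one output directory per respacing value, named by
    appending the respacing string to output_root.
    """
    return [
        "python3", "image_sample.py",
        "--timestep_respacing", respacing,
        "--use_ddim", "True",
        "--model_path", model_path,
        "--output_npz_dir", f"{output_root}_{respacing}",
    ]


if __name__ == "__main__":
    # Hypothetical checkpoint path and output prefix, for illustration only.
    model_path = "logs/ema_checkpoint.pt"
    output_root = "sampled_images"
    for count in (50, 100, 250):
        kept = ddim_timesteps(1000, count)
        cmd = build_sample_command(f"ddim{count}", model_path, output_root)
        print(f"ddim{count}: stride {kept[1] - kept[0]}, {len(kept)} steps")
        print("  " + " ".join(cmd))
```

Each command list could be passed to `subprocess.run` to launch the run; a count with no matching integer stride (for example 300 out of 1000 steps) raises `ValueError` in this sketch.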

Contact

For any questions or inquiries, please contact Hoang An AIRS.
