

CLIP Diffusion Art

Fine-tune diffusion models on custom datasets, sample with text conditioning via CLIP guidance, and upscale the results with SwinIR super-resolution.

📌 Dataset with public domain artworks created for this project:

  Artworks in Public Domain

📌 Link to an interactive notebook run:

  Stunning Art with CLIP Guided Diffusion+SwinIR

📌 Wandb logging is integrated for training and sampling.


Generated Samples

Prompts used for the generated samples include:

"vibrant watercolor painting of a flower, artstation HQ"
"beautiful matte painting of dystopian city, Behance HD"
"artstation HQ, photorealistic depiction of an alien city"


For more generated artworks, visit this report


Super-resolution Results





Credits

Developed using techniques and architectures borrowed from the original work of the authors of the projects referenced in the Installation section below: OpenAI's improved-diffusion and CLIP, crowsonkb's guided-diffusion fork, and JingyunLiang's SwinIR.

Huge thanks for all their great work! I highly recommend checking out these repos.


Installation

git clone https://github.com/sreevishnu-damodaran/clip-diffusion-art.git -q
cd clip-diffusion-art
pip install -e . -q
git clone https://github.com/JingyunLiang/SwinIR.git -q
git clone https://github.com/crowsonkb/guided-diffusion -q
pip install -e guided-diffusion -q
git clone https://github.com/openai/CLIP -q
pip install -e ./CLIP -q
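
After installation, a quick import check can confirm that everything resolved correctly. This is only a sketch and assumes the editable installs expose packages named clip_diffusion_art, guided_diffusion and clip, matching the directory layout used by the training and sampling scripts:

# assumption: package names match the repo directories listed above
python -c "import clip_diffusion_art, guided_diffusion, clip; print('ok')"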

Dataset

Public Domain Artworks dataset used in this repo:

https://www.kaggle.com/sreevishnudamodaran/artworks-in-public-domain

Additional details are available in datasets/README.md
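
If you have the Kaggle CLI configured, one way to fetch the dataset directly is sketched below (the archive name follows Kaggle's default <slug>.zip convention, and the extraction path is a placeholder):

# requires the Kaggle CLI and an API token; extraction path is a placeholder
kaggle datasets download -d sreevishnudamodaran/artworks-in-public-domain
unzip -q artworks-in-public-domain.zip -d path/to/images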


Training & Fine-tuning

Choose the hyperparameters for training. These are reasonable defaults for fine-tuning on a custom dataset with a 16GB GPU on Colab or Kaggle:

MODEL_FLAGS="--image_size 256 --num_channels 128 --num_res_blocks 2 --num_heads 1 --attention_resolutions 16"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear --learn_sigma True --rescale_learned_sigmas True --rescale_timesteps True --use_scale_shift_norm False"
TRAIN_FLAGS="--lr 5e-6 --save_interval 500 --batch_size 16 --use_fp16 True --wandb_project diffusion-art-train --use_checkpoint True --resume_checkpoint pretrained_models/lsun_uncond_100M_1200K_bs128.pt"
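
The TRAIN_FLAGS above resume from the LSUN 256x256 unconditional checkpoint released with OpenAI's improved-diffusion. A sketch of placing it where the flags expect it (double-check the download URL against the improved-diffusion README before use):

mkdir -p pretrained_models
# URL taken from the openai/improved-diffusion README; verify it is still current
wget -P pretrained_models https://openaipublic.blob.core.windows.net/diffusion/march-2021/lsun_uncond_100M_1200K_bs128.pt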

Once the hyperparameters are set, run the training job as follows:

python clip_diffusion_art/train.py --data_dir path/to/images $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

Refer to OpenAI's improved-diffusion repository for more details on choosing hyperparameters and on selecting other pre-trained weights.


Download SR pre-trained weights

wget -P pretrained_models https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth

Passing the --sr_model_path flag to sample.py applies super-resolution to each image after sampling.

Sample Images with CLIP Guidance

python clip_diffusion_art/sample.py \
"beautiful matte painting of dystopian city, Behance HD" \
--checkpoint 256x256_clip_diffusion_art.pt \
--model_config "clip_diffusion_art/configs/256x256_clip_diffusion_art.yaml" \
--sampling "ddim50" \
--cutn 60 \
--cutn_batches 4 \
--sr_model_path pretrained_models/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth \
--large_sr \
--output_dir "outputs"

Options:

--images - image prompts (default=None)
--checkpoint - diffusion model checkpoint to use for sampling
--model_config - diffusion model config yaml
--wandb_project - enable wandb logging and use this project name
--wandb_name - optional run name to use for wandb logging
--wandb_entity - optional entity to use for wandb logging
--num_samples - number of samples to generate (default=1)
--batch_size - batch size for the diffusion model (default=1)
--sampling - timestep respacing sampling methods to use (default="ddim50", choices=[25, 50, 100, 150, 250, 500, 1000, ddim25, ddim50, ddim100, ddim150, ddim250, ddim500, ddim1000])
--diffusion_steps - number of diffusion timesteps (default=1000)
--skip_timesteps - diffusion timesteps to skip (default=5)
--clip_denoised - enable to filter out noise from generation (default=False)
--randomize_class_disable - disables changing imagenet class randomly in each iteration (default=False)
--eta - the amount of noise to add during sampling (default=0)
--clip_model - CLIP pre-trained model to use (default="ViT-B/16", choices=["RN50","RN101","RN50x4","RN50x16","RN50x64","ViT-B/32","ViT-B/16","ViT-L/14"])
--skip_augs - enable to skip torchvision augmentations (default=False)
--cutn - the number of random crops to use (default=16)
--cutn_batches - number of batches of cuts to accumulate CLIP guidance over (default=4)
--init_image - init image to use while sampling (default=None)
--loss_fn - loss fn to use for CLIP guidance (default="spherical", choices=["spherical", "cos_spherical"])
--clip_guidance_scale - CLIP guidance scale (default=5000)
--tv_scale - controls smoothing in samples (default=100)
--range_scale - controls the range of RGB values in samples (default=150)
--saturation_scale - controls the saturation in samples (default=0)
--init_scale - controls the adherence to the init image (default=1000)
--scale_multiplier - scales clip_guidance_scale, tv_scale and range_scale (default=50)
--disable_grad_clamp - disable gradient clamping (default=False)
--sr_model_path - SwinIR super-resolution model checkpoint (default=None)
--large_sr - enable to use large SwinIR super-resolution model (default=False)
--output_dir - output images directory (default="output_dir")
--seed - the random seed (default=47)
--device - the device to use
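
For example, a lighter run that starts from an init image and logs results to wandb could look like the sketch below (the checkpoint, config, init image path, and wandb project name are placeholders; all flags are documented above):

# placeholder paths and project name; all flags are documented above
python clip_diffusion_art/sample.py \
"vibrant watercolor painting of a flower, artstation HQ" \
--checkpoint 256x256_clip_diffusion_art.pt \
--model_config "clip_diffusion_art/configs/256x256_clip_diffusion_art.yaml" \
--sampling "ddim100" \
--init_image path/to/init.jpg \
--init_scale 1000 \
--wandb_project diffusion-art-sampling \
--output_dir "outputs"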


Apply Super-resolution

Use the following to run super-resolution on other images, or use it for other restoration tasks (grayscale/color image denoising and JPEG compression artifact reduction):

python swinir.py <path-to-images-dir> --task "real_sr"

data_dir - directory with images

--task - image restoration task (default='real_sr', choices=['real_sr', 'color_dn', 'gray_dn', 'jpeg_car'])
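
For instance, a color-denoising run might look like this sketch (the input directory is a placeholder, and it assumes the wrapper applies sensible defaults for task-specific settings such as noise level):

# placeholder input directory; assumes default task-specific settings
python swinir.py path/to/noisy_images --task "color_dn"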
