
diffusion-fast's Introduction

Diffusion, fast

Repository for the blog post: Accelerating Generative AI Part III: Diffusion, Fast. You can find a rundown of the techniques on the 🤗 Diffusers website too.


Summary of the optimizations:

  • Running with the bfloat16 precision
  • scaled_dot_product_attention (SDPA)
  • torch.compile
  • Combining q,k,v projections for attention computation
  • Dynamic int8 quantization

These techniques are fairly generalizable to other pipelines too, as we show below.

Table of contents:

  • Setup 🛠️
  • Running a benchmarking experiment 🏎️
  • Improvements, progressively 📈 📊
  • Results from other pipelines 🌋

Setup 🛠️

We rely on pure PyTorch for the optimizations. You can refer to the Dockerfile to get the complete development environment setup.

For hardware, we used an 80GB 400W A100 GPU with its memory clock set to the maximum rate (1593 MHz in our case).
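To replicate this, you can pin the memory clock with nvidia-smi before benchmarking, for example sudo nvidia-smi --lock-memory-clocks=1593 (the right value depends on your GPU; nvidia-smi -q -d SUPPORTED_CLOCKS lists the supported clocks).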

Running a benchmarking experiment 🏎️

run_benchmark.py is the main script for benchmarking the different optimization techniques. After an experiment finishes, you should see two files:

  • A .csv file with all the benchmarking numbers.
  • A .jpeg image file corresponding to the experiment.

Refer to the experiment-scripts/run_sd.sh script for some reference experiment commands.

Notes on running PixArt-Alpha experiments:

  • Use run_experiment_pixart.py for this.
  • Uninstall the current installation of diffusers and reinstall it like so: pip install git+https://github.com/huggingface/diffusers@fuse-projections-pixart.
  • Refer to the experiment-scripts/run_pixart.sh script for some reference experiment commands.

(Support for PixArt-Alpha is experimental.)
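For reference, a minimal PixArt-Alpha sketch looks like the following; the checkpoint name and prompt are illustrative assumptions, and the fused-projection support comes from the branch above:

from diffusers import PixArtAlphaPipeline
import torch

# Load PixArt-Alpha in bfloat16; swap in whichever checkpoint you benchmark.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.bfloat16
).to("cuda")

# PixArt uses a diffusion transformer rather than a UNet, so compile `transformer`.
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)

image = pipe("a small cactus with a happy face", num_inference_steps=20).images[0]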

You can use the prepare_results.py script to generate a consolidated CSV file and a plot visualizing the results. It's best used after you have run a few benchmarking experiments and have their corresponding CSV files.

To run the script, you need the following dependencies:

  • pandas
  • matplotlib
  • seaborn
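All three are available on PyPI and can be installed with pip install pandas matplotlib seaborn.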

Improvements, progressively 📈 📊

Baseline
from diffusers import StableDiffusionXLPipeline

# Load the pipeline in full-precision and place its model components on CUDA.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"
).to("cuda")

# Run the attention ops without SDPA (i.e., force the vanilla attention processor).
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

With this, we're at our baseline latency. [Benchmark plot omitted.]

Bfloat16

Running the pipeline in a reduced precision like bfloat16 roughly halves the memory traffic relative to float32 and speeds up the compute-bound ops:
from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# Run the attention ops without SDPA (i.e., force the vanilla attention processor).
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

💡 We later ran the experiments in float16 and found that recent versions of torchao do not incur numerical problems with float16.
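If you want to try float16 instead, only the dtype changes; a minimal sketch:

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")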

scaled_dot_product_attention

Diffusers dispatches to torch.nn.functional.scaled_dot_product_attention by default on PyTorch 2.0+, so enabling SDPA simply means no longer forcing the vanilla attention processor:
from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]
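If you prefer being explicit over relying on the default, you can opt into the SDPA-backed processor directly; AttnProcessor2_0 is the scaled_dot_product_attention implementation in Diffusers:

from diffusers.models.attention_processor import AttnProcessor2_0

# Explicitly select the SDPA attention processor for the UNet and VAE.
pipe.unet.set_attn_processor(AttnProcessor2_0())
pipe.vae.set_attn_processor(AttnProcessor2_0())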

torch.compile

First, configure some compiler flags:

from diffusers import StableDiffusionXLPipeline
import torch

# Set the following compiler flags to make things go brrr.
torch._inductor.config.conv_1x1_as_mm = True  # treat 1x1 convs as matmuls
torch._inductor.config.coordinate_descent_tuning = True  # tune kernel configs more aggressively
torch._inductor.config.epilogue_fusion = False  # skip epilogue fusion (hurts here)
torch._inductor.config.coordinate_descent_check_all_directions = True  # broaden the tuning search

Then load the pipeline:

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

Compile and perform inference:

# Use channels_last memory format for the convolution-heavy UNet and VAE.
pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

# Compile the UNet and VAE.
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

# First call to `pipe` will be slow, subsequent ones will be faster.
image = pipe(prompt, num_inference_steps=30).images[0]

Combining attention projection matrices

fuse_qkv_projections() concatenates the q, k, and v projection matrices, so attention runs one larger matmul instead of three smaller ones and utilizes the GPU better:
from diffusers import StableDiffusionXLPipeline
import torch

# Configure the compiler flags.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# Combine attention projection matrices.
pipe.fuse_qkv_projections()

# Use channels_last memory format for the convolution-heavy UNet and VAE.
pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

# Compile the UNet and VAE.
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

# First call to `pipe` will be slow, subsequent ones will be faster.
image = pipe(prompt, num_inference_steps=30).images[0]
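To see why fusing helps, here is a tiny standalone sketch (not Diffusers internals) showing that one concatenated projection matmul is equivalent to three separate ones:

import torch

d = 64
x = torch.randn(1, 10, d)
wq, wk, wv = (torch.randn(d, d) for _ in range(3))

# Concatenating the three projection matrices into one [d, 3d] matrix...
w_fused = torch.cat([wq, wk, wv], dim=1)

# ...lets a single matmul produce q, k, and v, which are then split apart.
q, k, v = (x @ w_fused).chunk(3, dim=-1)

assert torch.allclose(q, x @ wq, atol=1e-5)
assert torch.allclose(k, x @ wk, atol=1e-5)
assert torch.allclose(v, x @ wv, atol=1e-5)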

Dynamic quantization

Dynamic int8 quantization stores the weights in int8 and quantizes the activations on the fly at inference time. Quantization has overhead of its own, so it only pays off for sufficiently large matmuls; the filtering functions below exclude the layer shapes where it would slow things down.

Start by setting the compiler flags (this time, there are two new ones):

from diffusers import StableDiffusionXLPipeline
import torch

from torchao.quantization import apply_dynamic_quant, swap_conv2d_1x1_to_linear

# Compiler flags. The last two are new.
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True  # fuse the int8 matmul with the rescaling multiply
torch._inductor.config.use_mixed_mm = True  # allow matmuls with mixed input dtypes

Then write the filtering functions that decide where dynamic quantization is applied:

def dynamic_quant_filter_fn(mod, *args):
    # Quantize only Linear layers that are large enough to benefit; the
    # (in_features, out_features) shapes listed below were found to be slower
    # when quantized, so they are excluded.
    return (
        isinstance(mod, torch.nn.Linear)
        and mod.in_features > 16
        and (mod.in_features, mod.out_features)
        not in [
            (1280, 640),
            (1920, 1280),
            (1920, 640),
            (2048, 1280),
            (2048, 2560),
            (2560, 1280),
            (256, 128),
            (2816, 1280),
            (320, 640),
            (512, 1536),
            (512, 256),
            (512, 512),
            (640, 1280),
            (640, 1920),
            (640, 320),
            (640, 5120),
            (640, 640),
            (960, 320),
            (960, 640),
        ]
    )


def conv_filter_fn(mod, *args):
    # Swap only the pointwise (1x1) convolutions with 128 input or output
    # channels; these are the ones worth turning into Linear layers.
    return (
        isinstance(mod, torch.nn.Conv2d)
        and mod.kernel_size == (1, 1)
        and 128 in [mod.in_channels, mod.out_channels]
    )

Then we're ready for inference:

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

# Combine attention projection matrices.
pipe.fuse_qkv_projections()

# Change the memory layout.
pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

# Swap the pointwise convs with linears.
swap_conv2d_1x1_to_linear(pipe.unet, conv_filter_fn)
swap_conv2d_1x1_to_linear(pipe.vae, conv_filter_fn)

# Apply dynamic quantization.
apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)

# Compile.
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]
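To measure the end-to-end latency yourself, a minimal timing sketch (continuing from the block above) looks like this; warm-up runs come first so compilation and autotuning are excluded from the measurement:

import time

# Warm up: the first compiled calls are slow.
for _ in range(2):
    _ = pipe(prompt, num_inference_steps=30).images[0]

torch.cuda.synchronize()
start = time.perf_counter()
image = pipe(prompt, num_inference_steps=30).images[0]
torch.cuda.synchronize()
print(f"Latency: {time.perf_counter() - start:.2f} seconds")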

Results from other pipelines 🌋

SSD-1B
SD v1-5
Pixart-Alpha
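The same recipe carries over with minimal changes. As a hedged sketch for SSD-1B (a distilled SDXL variant, so it loads through StableDiffusionXLPipeline; the prompt is illustrative):

from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "segmind/SSD-1B", torch_dtype=torch.bfloat16
).to("cuda")

# Apply the same optimizations as above.
pipe.fuse_qkv_projections()
pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

image = pipe("An astronaut riding a green horse", num_inference_steps=30).images[0]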

diffusion-fast's People

Contributors

hdcharles, sayakpaul


diffusion-fast's Issues

lora support for optimization

I was using the torch.compile optimization to speed up inference.

I am using a DreamBooth LoRA model that was trained on Juggernaut.

When making the inference, it's not compiling:

pipe.load_lora_weights(prj_path, weight_name="pytorch_lora_weights.safetensors")

Is there any way I can use this optimization with DreamBooth LoRA models?
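One workaround, sketched here as an assumption rather than an official answer from the repo: fuse the LoRA weights into the base weights before compiling, so torch.compile sees a plain UNet:

pipe.load_lora_weights(prj_path, weight_name="pytorch_lora_weights.safetensors")

# Fold the LoRA weights into the base weights, then compile as usual.
pipe.fuse_lora()
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)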

packages I am using

Package Version


absl-py 2.1.0
accelerate 0.26.1
aiofiles 23.2.1
aiohttp 3.9.3
aiosignal 1.3.1
albumentations 1.3.1
alembic 1.13.1
altair 5.2.0
annotated-types 0.6.0
anyio 3.7.1
arrow 1.3.0
async-timeout 4.0.3
attrs 23.2.0
Authlib 1.3.0
autotrain-advanced 0.6.92
bitsandbytes 0.42.0
Brotli 1.1.0
cachetools 5.3.2
certifi 2023.11.17
cffi 1.16.0
charset-normalizer 3.3.2
click 8.1.7
cmaes 0.10.0
cmake 3.28.3
codecarbon 2.2.3
colorlog 6.8.2
contourpy 1.1.1
cryptography 42.0.3
cycler 0.12.1
datasets 2.14.7
diffusers 0.21.4
dill 0.3.8
docstring-parser 0.15
einops 0.6.1
evaluate 0.3.0
exceptiongroup 1.2.0
fastapi 0.104.1
ffmpy 0.3.1
filelock 3.13.1
fonttools 4.47.2
frozenlist 1.4.1
fsspec 2023.10.0
fuzzywuzzy 0.18.0
google-auth 2.27.0
google-auth-oauthlib 1.0.0
GPUtil 1.4.0
gradio 3.41.0
gradio-client 0.5.0
greenlet 3.0.3
grpcio 1.60.0
h11 0.14.0
hf-transfer 0.1.5
httpcore 1.0.2
httpx 0.26.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.33.1
importlib-metadata 7.0.1
importlib-resources 6.1.1
inflate64 1.0.0
install 1.3.5
invisible-watermark 0.2.0
ipadic 1.0.0
itsdangerous 2.1.2
Jinja2 3.1.3
jiwer 3.0.2
joblib 1.3.1
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
kiwisolver 1.4.5
lazy-loader 0.3
loguru 0.7.0
Mako 1.3.2
Markdown 3.5.2
markdown-it-py 3.0.0
MarkupSafe 2.1.4
matplotlib 3.7.4
mdurl 0.1.2
mpmath 1.3.0
multidict 6.0.4
multiprocess 0.70.16
multivolumefile 0.2.3
networkx 3.1
nltk 3.8.1
numpy 1.24.4
nvidia-cublas-cu11 11.11.3.6
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu11 11.8.87
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu11 11.8.89
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu11 8.7.0.84
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu11 10.9.0.58
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu11 10.3.0.86
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu11 11.4.1.48
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu11 11.7.5.86
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu11 2.19.3
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
nvidia-nvtx-cu11 11.8.86
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
opencv-python 4.9.0.80
opencv-python-headless 4.9.0.80
optuna 3.3.0
orjson 3.9.12
packaging 23.1
pandas 2.0.3
peft 0.8.2
Pillow 10.0.0
pip 20.0.2
pkg-resources 0.0.0
pkgutil-resolve-name 1.3.10
protobuf 4.23.4
psutil 5.9.8
py-cpuinfo 9.0.0
py7zr 0.20.6
pyarrow 15.0.0
pyarrow-hotfix 0.6
pyasn1 0.5.1
pyasn1-modules 0.3.0
pybcj 1.0.2
pycparser 2.21
pycryptodomex 3.20.0
pydantic 2.4.2
pydantic-core 2.10.1
pydub 0.25.1
pygments 2.17.2
pyngrok 7.0.3
pynvml 11.5.0
pyparsing 3.1.1
pyppmd 1.0.0
python-dateutil 2.8.2
python-dotenv 1.0.1
python-multipart 0.0.6
pytorch-triton 3.0.0+901819d2b6
pytz 2023.4
PyWavelets 1.4.1
PyYAML 6.0.1
pyzstd 0.15.9
qudida 0.0.4
rapidfuzz 2.13.7
referencing 0.33.0
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
responses 0.18.0
rich 13.7.0
rouge-score 0.1.2
rpds-py 0.17.1
rsa 4.9
sacremoses 0.0.53
safetensors 0.4.2
scikit-image 0.21.0
scikit-learn 1.3.0
scipy 1.10.1
semantic-version 2.10.0
sentencepiece 0.1.99
setuptools 44.0.0
shtab 1.6.5
six 1.16.0
sniffio 1.3.0
SQLAlchemy 2.0.25
starlette 0.27.0
sympy 1.12
tensorboard 2.14.0
tensorboard-data-server 0.7.2
texttable 1.7.0
threadpoolctl 3.2.0
tifffile 2023.7.10
tiktoken 0.5.1
tokenizers 0.15.1
toolz 0.12.1
torch 2.3.0.dev20240221+cu118
torchaudio 2.2.0+cu118
torchtriton 2.0.0+f16138d447
torchvision 0.17.0
tqdm 4.65.0
transformers 4.37.0
triton 2.2.0
trl 0.7.11
types-python-dateutil 2.8.19.20240106
typing-extensions 4.9.0
tyro 0.7.0
tzdata 2023.4
urllib3 2.2.0
uvicorn 0.22.0
websockets 11.0.3
Werkzeug 2.3.6
wheel 0.34.2
xformers 0.0.24
xgboost 1.7.6
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0


quantization issue for weight only int8 with fp16

bf16 or fp32 works. I tested change_linear_weights_to_int8_woqtensors, change_linear_weights_to_int8_dqtensors, and change_linear_weights_to_int4_woqtensors (they're generally safer than the module-swap methods since they dispatch on the function being used).

All three work reasonably well. int8 weight-only quant doesn't work for fp16, though: if both the weight and activation are fp16, you overflow the fp16 range before the rescale.

see:

https://github.com/sayakpaul/sdxl-fast/pull/2/files
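A minimal sketch of the safe combination described above, assuming the torchao API of that era (bfloat16 weights plus int8 weight-only quantization, which avoids the fp16 overflow):

import torch
from torchao.quantization import change_linear_weights_to_int8_woqtensors

# Keep the UNet in bfloat16: the wider exponent range avoids the overflow
# that fp16 weights and activations hit before the rescale.
pipe.unet.to(torch.bfloat16)
change_linear_weights_to_int8_woqtensors(pipe.unet)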
