
[GPT beats diffusion🔥] [scaling laws in visual generation📈] Official impl. of "Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction". An *ultra-simple, user-friendly yet state-of-the-art* codebase for autoregressive image generation!

License: MIT License

auto-regressive-model diffusion-models image-generation transformers autoregressive-models generative-ai generative-model gpt gpt-2 large-language-models

var's Introduction

VAR: a new visual generation method that elevates GPT-style models beyond diffusion🚀, with power-law scaling laws observed📈

demo platform · arXiv · huggingface weights · SOTA

Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction


🕹️ Try and Play with VAR!

We provide a demo website for you to play with VAR models and generate images interactively. Enjoy the fun of visual autoregressive modeling!

We also provide demo_sample.ipynb for you to see more technical details about VAR.

What's New?

🔥 Introducing VAR: a new paradigm in autoregressive visual generation✨:

Visual Autoregressive Modeling (VAR) redefines the autoregressive learning on images as coarse-to-fine "next-scale prediction" or "next-resolution prediction", diverging from the standard raster-scan "next-token prediction".
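To make the idea concrete, here is a toy, runnable sketch of a coarse-to-fine sampling loop. The helper `next_scale_generation`, the dummy predictor, and the exact scale schedule are assumptions for illustration only, not the repository's API:

import torch

# Toy sketch of "next-scale prediction" (hypothetical helper, NOT the repo API):
# at every step, all tokens of the next, higher-resolution token map are sampled
# in parallel, conditioned on every coarser map generated so far.
def next_scale_generation(predict_logits, vocab_size=4096,
                          scales=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16)):
    token_maps, context = [], []
    for side in scales:                                   # coarse -> fine resolutions
        logits = predict_logits(context, side)            # shape (side*side, vocab_size)
        probs = torch.softmax(logits.float(), dim=-1)
        tokens = torch.multinomial(probs, 1).view(side, side)   # whole map sampled at once
        token_maps.append(tokens)
        context.append(tokens)                            # conditioning for the next scale
    return token_maps                                     # a VQVAE decoder would turn these into an image

# Dummy stand-in for the VAR transformer, just to make the sketch executable.
dummy_predictor = lambda ctx, side: torch.randn(side * side, 4096)
print([m.shape for m in next_scale_generation(dummy_predictor)])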

🔥 For the first time, GPT-style autoregressive models surpass diffusion models🚀:

🔥 Discovering power-law Scaling Laws in VAR transformers📈:

🔥 Zero-shot generalizability🛠️:

For a deep dive into our analyses, discussions, and evaluations, check out our paper.

VAR zoo

We provide pretrained VAR models for you to play with, which can be downloaded from the following links:

model reso. FID rel. cost #params HF weights🤗
VAR-d16 256 3.55 0.4 310M var_d16.pth
VAR-d20 256 2.95 0.5 600M var_d20.pth
VAR-d24 256 2.33 0.6 1.0B var_d24.pth
VAR-d30 256 1.97 1 2.0B var_d30.pth
VAR-d30-re 256 1.80 1 2.0B var_d30.pth

You can load these models to generate images via the code in demo_sample.ipynb. Note: you need to download vae_ch160v4096z32.pth first.
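A rough sketch of that flow is below. The builder name and its arguments are assumptions (the import names follow a snippet quoted in the issues further down); follow demo_sample.ipynb for the real API:

import torch
from models import VQVAE, build_var   # import names as quoted in demo_sample.ipynb; verify against the notebook

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hypothetical builder call -- the real signature lives in demo_sample.ipynb / models/__init__.py.
vae, var = build_var(depth=16, device=device)          # VAR-d16 plus the shared multi-scale VQVAE

# Checkpoints referenced in this README: vae_ch160v4096z32.pth and var_d16.pth.
vae.load_state_dict(torch.load('vae_ch160v4096z32.pth', map_location='cpu'), strict=True)
var.load_state_dict(torch.load('var_d16.pth', map_location='cpu'), strict=True)
vae.eval(); var.eval()

# Class-conditional sampling with the parameters quoted in the FID section below.
with torch.inference_mode():
    label_B = torch.tensor([207, 360, 387, 974], device=device)   # arbitrary ImageNet-1k class ids
    images = var.autoregressive_infer_cfg(B=label_B.numel(), label_B=label_B,
                                          cfg=1.5, top_p=0.96, top_k=900, more_smooth=False)
print(images.shape)   # expected roughly (4, 3, 256, 256)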

Installation

  1. Install torch>=2.0.0.

  2. Install other pip packages via pip3 install -r requirements.txt.

  3. Prepare the ImageNet dataset

Assume the ImageNet dataset is in `/path/to/imagenet`. It should be organized like this:
    /path/to/imagenet/:
        train/:
            n01440764: 
                many_images.JPEG ...
            n01443537:
                many_images.JPEG ...
        val/:
            n01440764:
                ILSVRC2012_val_00000293.JPEG ...
            n01443537:
                ILSVRC2012_val_00000236.JPEG ...
    

NOTE: The arg --data_path=/path/to/imagenet should be passed to the training script. (A quick way to sanity-check this layout is sketched right after this list.)

  4. (Optional) install and compile flash-attn and xformers for faster attention computation. Our code will automatically use them if installed. See models/basic_var.py#L15-L30.
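A quick, optional sanity check of the ImageNet layout from step 3, using torchvision's ImageFolder; the path below is a placeholder:

from torchvision import datasets

data_path = '/path/to/imagenet'    # placeholder; pass the same path via --data_path

# ImageFolder expects exactly the train/<wnid>/*.JPEG layout shown above.
train_set = datasets.ImageFolder(f'{data_path}/train')
val_set = datasets.ImageFolder(f'{data_path}/val')
print(f'train={len(train_set)}, val={len(val_set)}, classes={len(train_set.classes)}')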

Training Scripts

To train VAR-{d16, d20, d24, d30, d36-s} on ImageNet 256x256 or 512x512, you can run one of the following commands:

# d16, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --depth=16 --bs=768 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1
# d20, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --depth=20 --bs=768 --ep=250 --fp16=1 --alng=1e-3 --wpe=0.1
# d24, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --depth=24 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-4 --wpe=0.01
# d30, 256x256
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --depth=30 --bs=1024 --ep=350 --tblr=8e-5 --fp16=1 --alng=1e-5 --wpe=0.01 --twde=0.08
# d36-s, 512x512 (-s means saln=1, shared AdaLN)
torchrun --nproc_per_node=8 --nnodes=... --node_rank=... --master_addr=... --master_port=... train.py \
  --depth=36 --saln=1 --pn=512 --bs=768 --ep=350 --tblr=8e-5 --fp16=1 --alng=5e-6 --wpe=0.01 --twde=0.08

A folder named local_output will be created to save the checkpoints and logs. You can monitor the training process by checking the logs in local_output/log.txt and local_output/stdout.txt, or using tensorboard --logdir=local_output/.

If your experiment is interrupted, just rerun the command, and the training will automatically resume from the last checkpoint in local_output/ckpt*.pth (see utils/misc.py#L344-L357).

Sampling & Zero-shot Inference

For FID evaluation, use var.autoregressive_infer_cfg(..., cfg=1.5, top_p=0.96, top_k=900, more_smooth=False) to sample 50,000 images (50 per class) and save them as PNG (not JPEG) files in a folder. Pack them into a .npz file via create_npz_from_sample_folder(sample_folder) in utils/misc.py#L344. Then use OpenAI's FID evaluation toolkit and the reference ground-truth npz file for 256x256 or 512x512 to evaluate FID, IS, precision, and recall.

Note that a relatively small cfg=1.5 is used as a trade-off between image quality and diversity. You can adjust it to cfg=5.0, or sample with autoregressive_infer_cfg(..., more_smooth=True) for better visual quality. We'll provide the sampling script later.
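A sketch of that evaluation loop follows, assuming `var` is an already-loaded VAR model (see the loading sketch above); the folder name is arbitrary, and you may want to split the per-class batch further if memory-bound:

import os
import torch
from PIL import Image
from utils.misc import create_npz_from_sample_folder   # helper referenced above

sample_folder, per_class = 'fid_samples', 50            # 50 images per class -> 50,000 total
os.makedirs(sample_folder, exist_ok=True)

with torch.inference_mode():
    for cls in range(1000):                              # ImageNet-1k classes
        label_B = torch.full((per_class,), cls, device='cuda')
        imgs = var.autoregressive_infer_cfg(B=per_class, label_B=label_B,
                                            cfg=1.5, top_p=0.96, top_k=900, more_smooth=False)
        imgs = imgs.mul(255).add(0.5).clamp(0, 255).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
        for i, arr in enumerate(imgs):
            Image.fromarray(arr).save(f'{sample_folder}/{cls:04d}_{i:02d}.png')   # PNG, not JPEG

create_npz_from_sample_folder(sample_folder)             # pack into .npz for OpenAI's FID toolkit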

License

This project is licensed under the MIT License - see the LICENSE file for details.

Citation

If our work assists your research, feel free to give us a star ⭐ or cite us using:

@Article{VAR,
      title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction}, 
      author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang},
      year={2024},
      eprint={2404.02905},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

var's People

Contributors

ifighting, keyu-tian, nielsrogge


var's Issues

Training and Inference Scripts

Hi, amazing work indeed!

I wonder, could you please release the training and inference scripts so that your ImageNet 256x256 (and 512x512) results can be reproduced exactly?

In addition, could you please provide information about the number of nodes/GPUs used for each of the ImageNet-256 and ImageNet-512 experiments?

About the training code

Thanks for your great work.
I see the training code is empty. When will it be released? Thank you!

training code

Hi there,

thanks for your amazing work! I noticed the recent updates covered the training code. Will there also be updates about training scripts and data preparation? I'm particularly interested in training the VQVAE part.

About the input resolution scale

From the method described in the paper, if the output resolution is huge, such as 1024x2048, the actual generation time would be much larger than that of a diffusion model.

So, for large-image generation, what is the actual strength of this method?

About testing Zero-shot generalizability

Thank you for sharing the code!

I just noticed that train.py is an empty script and would like to know when I can get access to the training code. Also, is there any plan to release code for zero-shot generalization, such as class-conditional editing?

Thanks in advance!

About progressive training

Great work, author!
Since I found no training details in the paper (I don't know if I've missed something), I wonder if there is any ablation study or performance difference between progressive training and the naive approach? Or is it just for efficiency reasons, as you mentioned in #29?

RuntimeError: Error(s) in loading state_dict for VAR in demo_sample.ipynb

var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)

RuntimeError: Error(s) in loading state_dict for VAR:
Unexpected key(s) in state_dict: "blocks.0.attn.scale_mul_1H11", "blocks.1.attn.scale_mul_1H11", "blocks.2.attn.scale_mul_1H11", "blocks.3.attn.scale_mul_1H11", "blocks.4.attn.scale_mul_1H11", "blocks.5.attn.scale_mul_1H11", "blocks.6.attn.scale_mul_1H11", "blocks.7.attn.scale_mul_1H11", "blocks.8.attn.scale_mul_1H11", "blocks.9.attn.scale_mul_1H11", "blocks.10.attn.scale_mul_1H11", "blocks.11.attn.scale_mul_1H11", "blocks.12.attn.scale_mul_1H11", "blocks.13.attn.scale_mul_1H11", "blocks.14.attn.scale_mul_1H11", "blocks.15.attn.scale_mul_1H11".

How do you make Transformer generate tokens in parallel?

It is mentioned multiple times in the paper that all tokens at the same scale r are generated in parallel. Did I overlook something, or is there actually little description of how the VAR transformer generates tokens in parallel?

Question about the shared codebook

Hi, thanks for your wonderful work! I'm studying your work with big interest!

Is there any reason to use a shared codebook for multiple scales?
Intuitively, it seems there might be a possibility of a performance gain from separate codebooks (distinguishing the role of each codebook).
I wonder if you've tried an unshared codebook.

Again, thank you!

FID

Where is your eval code? I want to calculate FID, thank you!

AR Time Complexity?

Great results! :) I have a question though about complexity

In section 3.1, you provide three issues with autoregressive transformer models. The third says that an autoregressive transformer generating n^2 tokens will take n^6 time. In your Appendix you show this by assuming that for each token i the attentions with the previous (i-1)^2 tokens need to be computed, which takes O(i^2), so the total time is the sum of i=1 to n^2 of i^2 which is O(n^6).

However, in practice we cache the intermediate representations for each token during autoregressive sampling. See here for an explanation. This means for each i we reuse the attentions that were computed for the preceding (i-1) tokens, and so only have to compute attentions between the ith token and each of the previous tokens, which takes O(i) time. Therefore the sum is i=1 to n^2 of i, which is O(n^4), the same as the result that you report for your method.
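In symbols, this commenter's calculation reads (a restatement of the argument above, not a claim from the paper):

\text{without KV cache: } \sum_{i=1}^{n^2} O(i^2) = O(n^6),
\qquad
\text{with KV cache: } \sum_{i=1}^{n^2} O(i) = O(n^4).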

Please let me know if for some reason the caching trick doesn't apply to the AR methods you're comparing against, or if your method can also benefit from caching in that way to speed up generation!

class VQVAE forward function error

The VQVAE forward function has a code error: VectorQuantizer2.forward returns three values, but

h_BChw, usages, vq_loss, mean_entropy_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages)

tries to unpack four values from it.

What is your position embedding? Is 2D RoPE a good choice?

Bro, have you tried whether 2D RoPE would further improve model performance? I don't have that many GPUs, so I haven't tried it yet.

Actual Training Loss Curve

In the VAR model, a standard cross-entropy loss is employed, yet Figure 5 illustrates the scaling law on a modified loss function. Could you provide the precise equation detailing how the actual training and test losses are converted to this reduced form? Additionally, is there a training log available that records the actual cross-entropy loss values over the course of training? Access to this information would greatly enhance readers' comprehension of the training dynamics of the VAR model.

Thanks a lot!

Question about the unconditional generation

Hello! Firstly, thanks for your great work and for sharing the code and weights!

I have a quite simple question about unconditional generation.
Referring to Figure 4, I understood that it should be possible for the model to create images without class information in the SOS token.
[figure omitted]

When looking at demo_sample.ipynb (the autoregressive_infer_cfg function), it seems that label_B=None is not unconditional generation, because the label then appears to be randomly sampled from the 1000 classes.
Is the "without condition information" setting in the paper the same as label_B=None?

Also,
referring to class_emb, the embedding size is 1001, which is one more than the number of ImageNet classes,
but when I generate an image using the 1000th class token, I get a fairly random image.
Is generating an image with this class index the same as generating without condition information?

Sincerely,
Jongmin

Training resources

Hello, thanks for your excellent work!
Can you share the computation resources used for training VAR? Basically, the number of GPUs and training days per epoch.
This would be of great help!

How does the VQVAE (stage 1) use the multi-scale VectorQuantizer?

Thanks to the authors for open-sourcing the code. I noticed that Algorithm 2 in the paper says z_k keeps being interpolated as the resolution increases and is then fed into the decoder together, but in the code the decoder seems to use z_k directly, without the corresponding interpolation:

    class Decoder(nn.Module):

        def forward(self, z):
            # z to block_in
            # middle
            h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))

            for i_level in reversed(range(self.num_resolutions)):
                for i_block in range(self.num_res_blocks + 1):
                    h = self.up[i_level].block[i_block](h)
                    if len(self.up[i_level].attn) > 0:
                        h = self.up[i_level].attn[i_block](h)
                if i_level != 0:
                    h = self.up[i_level].upsample(h)

            h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
            return h

And the corresponding upsample function is also a plain upsampling:

    class Upsample2x(nn.Module):
        def __init__(self, in_channels):
            super().__init__()
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

        def forward(self, x):
            return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))

So where is Algorithm 2 reflected in the code? Thanks a lot!

flash-attn related question

Hi, I have a question about flash-attn. When I tried to enable flash-attn, I found that the branch shown in Figure 1 is never entered. I printed self.using_flash, attn_bias, and qkv.dtype, and found that attn_bias is never None (Figure 2).
[Figures 1 and 2: screenshots omitted]

So I modified the code as follows:
using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
changed to
using_flash = self.using_flash and qkv.dtype != torch.float32

assert attn_bias is None and qkv.dtype != torch.float32
changed to
assert qkv.dtype != torch.float32
But then I got the error shown in Figure 3.
[Figure 3: screenshot omitted]

So I went on to print the dtypes of the input q, k, and v (Figure 4).
[Figure 4: screenshot omitted]
It only worked after I added some extra logic to the code (screenshot omitted).
Is this a known bug? Please check, or point out if I did something wrong. Finally, here is my run command:
torchrun --nproc_per_node=8 --nnodes=8 --node_rank=1 train.py --depth=16 --bs=384 --ep=200 --fp16=1 --alng=1e-3 --wpe=0.1 --afuse=False

FID on Class-Conditioned Evaluation & Normalized Attn in Depth 16/30

Hi Keyu,

Thanks for sharing this interesting work. I have two questions and hope to get some insights from you.

(1) I am trying to evaluate the released ckpt on the ImageNet eval set. The reproduced FID is about ~7 for the depth-16 model. I suspect there is a config difference between our testing (we used the configs in the released demo and torchmetrics to evaluate) and the reported experiment. Would you please let us know the specific parameter settings used for sampling during evaluation? Thanks.

(2) I noticed that in a recent update, the released ckpts were replaced, along with the configs of the depth-16/30 VAR models. I wonder what the influence of normalized attention is, and how much performance improvement this modification brings. Thanks.

Best,
A loyal reader of your paper

Error reported during training phase

Hello, may I ask how to solve the following error during training with torch==2.0.1?

[04-22 21:18:11] (ata/private/VAR/train.py, line 41)=> [build PT data] ...

[04-22 21:18:11] (rivate/VAR/utils/data.py, line 34)=> [Dataset] len(train_set)=17028, len(val_set)=3552, num_classes=1000
[04-22 21:18:11] (rivate/VAR/utils/data.py, line 48)=> Transform [train] =
[04-22 21:18:11] (rivate/VAR/utils/data.py, line 51)=> Resize(size=288, interpolation=lanczos, max_size=None, antialias=warn)
[04-22 21:18:11] (rivate/VAR/utils/data.py, line 51)=> RandomCrop(size=(256, 256), padding=None)
[04-22 21:18:11] (rivate/VAR/utils/data.py, line 51)=> ToTensor()
[04-22 21:18:11] (rivate/VAR/utils/data.py, line 51)=> <function normalize_01_into_pm1 at 0x7f51d2e66a70>
[04-22 21:18:11] (rivate/VAR/utils/data.py, line 54)=> ---------------------------

[04-22 21:18:11] (rivate/VAR/utils/data.py, line 48)=> Transform [val] =
[04-22 21:18:11] (rivate/VAR/utils/data.py, line 51)=> Resize(size=288, interpolation=lanczos, max_size=None, antialias=warn)
[04-22 21:18:11] (rivate/VAR/utils/data.py, line 51)=> CenterCrop(size=(256, 256))
[04-22 21:18:11] (rivate/VAR/utils/data.py, line 51)=> ToTensor()
[04-22 21:18:11] (rivate/VAR/utils/data.py, line 51)=> <function normalize_01_into_pm1 at 0x7f51d2e66a70>
[04-22 21:18:11] (rivate/VAR/utils/data.py, line 54)=> ---------------------------

[04-22 21:18:11] (ata/private/VAR/train.py, line 64)=> [auto_resume] no ckpt found @ /opt/data/private/VAR/local_output/ar-ckpt*.pth
[04-22 21:18:11] (ata/private/VAR/train.py, line 64)=> [auto_resume quit]
[04-22 21:18:11] (ata/private/VAR/train.py, line 65)=> [dataloader multi processing] ... dataloader multi processing finished! (0.00s)
[04-22 21:18:11] (ata/private/VAR/train.py, line 71)=> [dataloader] gbs=1, lbs=1, iters_train=17028, types(tr, va)=('DatasetFolder', 'DatasetFolder')
[04-22 21:18:11] (rivate/VAR/models/var.py, line 98)=>
[constructor] ==== flash_if_available=True (0/16), fused_if_available=True (fusing_add_ln=0/16, fusing_mlp=0/16) ====
[VAR config ] embed_dim=1024, num_heads=16, depth=16, mlp_ratio=4.0
[drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0666667 (tensor([0.0000, 0.0044, 0.0089, 0.0133, 0.0178, 0.0222, 0.0267, 0.0311, 0.0356,
0.0400, 0.0444, 0.0489, 0.0533, 0.0578, 0.0622, 0.0667]))

[04-22 21:18:12] (rivate/VAR/models/var.py, line 239)=> [init_weights] VAR with init_std=0.0180422
Traceback (most recent call last):
File "/opt/data/private/VAR/train.py", line 334, in
try: main_training()
File "/opt/data/private/VAR/train.py", line 180, in main_training
) = build_everything(args)
File "/opt/data/private/VAR/train.py", line 98, in build_everything
vae_local.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
File "/opt/data/private/VAR/models/vqvae.py", line 95, in load_state_dict
return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
TypeError: Module.load_state_dict() got an unexpected keyword argument 'assign'

Questions About Parallel Decoding

Hello, thank you for your outstanding work. I still have some doubts regarding parallel decoding and the use of cross-entropy loss. This setup assumes that tokens at the same resolution are independent of each other, which raises the question of whether, and why, this approach can model the tokens within one resolution well.

Consider a scenario with only two samples: X0: (1, 2) and X1: (2, 1). If we use a query vector of length 2 for parallel decoding and train with independent cross-entropy losses, the distribution the model learns maps the query vector to [0.5, 0.5] at each position. Sampling would then produce (1, 1), (1, 2), (2, 1), and (2, 2) with equal probability, which deviates from the true distribution. I am unsure if there is a misunderstanding on my part, and I would greatly appreciate your guidance on this matter.
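A tiny numeric restatement of this commenter's example (illustrative only, not a claim about what VAR actually learns):

from itertools import product

# The two training samples from the example above; positions within one scale are
# decoded independently under a per-position cross-entropy objective.
data = [(1, 2), (2, 1)]

# Per-position marginals an independence-assuming model would learn: 0.5 / 0.5 at each position.
marginal = {pos: {v: sum(s[pos] == v for s in data) / len(data) for v in (1, 2)}
            for pos in (0, 1)}

# Independent sampling then gives probability 0.25 to every pair, including the
# never-observed (1, 1) and (2, 2) -- the mismatch described above.
for pair in product((1, 2), repeat=2):
    print(pair, marginal[0][pair[0]] * marginal[1][pair[1]])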

How to set up progressive training when training VAR?

Great job indeed!

Currently, I am attempting to train VAR using the script and code you provided. Upon inspecting the training code and args list, it appears that there are three args, namely pg, pg0, and pgwp, designated for progressive training. I am curious about the hyperparameters you configured for your training, as well as which models were trained progressively.

Furthermore, I am interested in understanding why the VAR class's self.prog_si in var.py is initially set to -1 and remains unchanged throughout. It seems that neither train.py nor trainer.py resets the prog_si attribute of the VAR class.

Can it do super-resolution?

Can VAR do super-resolution, like GigaGAN's super-resolution, for example? GigaGAN is the most impressive super-resolution algorithm so far.
And if yes, would you be able to add support for it later, perhaps next month or so?

Resource consumption

Great work! I am wondering how much GPU time this work requires for training the different models? I can't find a description in the paper.

The performance of VAR Tokenizer

What is the performance of the VAR tokenizer? It is trained on OpenImages, while some other VQGAN tokenizers are trained on ImageNet only. I wonder about the performance gain brought by the pretraining data.

There are errors in the model folders .py files

hmm, there are new errors in the model folders .py files saying "Import ".###" could not be resolvedPyright(reportMissingImports)" to .py files that are ALL already in the VAR respiratory, even after restarting runtime.

Keep getting this error in demo_sample.ipynb

in <cell line: 10>()
8 setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
9 setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
---> 10 from models import VQVAE, build_var
     11
12 MODEL_DEPTH = 16 # TODO: =====> please specify MODEL_DEPTH <=====

ModuleNotFoundError: No module named 'models'

But when trying to pip-install a "models" module...

Using cached models-0.9.3.tar.gz (16 kB)
error: subprocess-exited-with-error

× python setup.py egg_info did not run successfully.
│ exit code: 1
╰─> See above for output.

note: This error originates from a subprocess, and is likely not a problem with pip.
Preparing metadata (setup.py) ... error
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

Abnormal sample results with `demo_sample.ipynb`

Thanks for sharing this nice work!

I tried to sample with demo_sample.ipynb and var_d16.pth without any changes, but I got these abnormal results😢:
[output image omitted]

The following is my environment:

python                    3.10.14
torch                     2.0.1                    
torchvision               0.15.2                              
transformers              4.40.1               
triton                    2.0.0     
pillow                    10.3.0 
pytz                      2024.1 
typed-argument-parser     1.10.0             

Does this result from a precision problem or something else? I would appreciate it if you could help with this problem.🥺

Implications of Classifier-Free Guidance in Auto-regressive Models

Hi, thank you for the insightful work!

I have some concerns regarding the classifier-free guidance (CFG) in auto-regressive models.

CFG in this work is implemented as follows:

VAR/models/var.py, lines 191-192 (commit 1ae5177):

t = cfg * ratio
logits_BlV = (1+t) * logits_BlV[:B] - t * logits_BlV[B:]
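For reference, a standalone restatement of this logit-space combination; the batch layout (conditional logits in the first half, unconditional in the second) is an assumption inferred from the indexing above, and the shapes are illustrative:

import torch

def cfg_combine(logits_BlV: torch.Tensor, cfg: float, ratio: float) -> torch.Tensor:
    # Assumes the first half of the batch holds class-conditional logits and the second
    # half the unconditional (null-class) logits, as the snippet above suggests.
    B = logits_BlV.shape[0] // 2
    t = cfg * ratio                                   # guidance strength, ramped by `ratio`
    return (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]

# e.g. 2 conditional + 2 unconditional rows, 16 tokens, vocab 4096
print(cfg_combine(torch.randn(4, 16, 4096), cfg=1.5, ratio=0.5).shape)   # torch.Size([2, 16, 4096])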

However, it's important to note that CFG in auto-regressive models differs fundamentally from that in diffusion models (as outlined in Section 4 of this blog). In essence, the guidance in diffusion models is not theoretically applicable to auto-regressive models.

I am curious if this difference yields any notable empirical results. Have you conducted any quantitative or qualitative studies on the impact of CFG on this auto-regressive model? I would greatly appreciate any insights or empirical findings you could share on this subject.

Dtype error with flash-attention

This error occurs when sampling with pretrained model:
"/xxx/VAR/models/basic_var.py", line 113, in forward
oup = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
RuntimeError: FlashAttention only support fp16 and bf16 data type.

The problem is that, while qkv is initially fp16, the scale_mul in line 101 of basic_var.py is fp32, which makes q and k become fp32.

update: F.normalize(q, dim=-1) also changes the dtype of q to fp32.
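A small illustration of the type promotion described in this report, plus one possible workaround (casting back before the fp16/bf16-only kernel); this is an assumption, not the repository's fix:

import torch

q = torch.randn(2, 4, 8, 64, dtype=torch.float16)   # fp16 query, as flash-attn expects
scale_mul = torch.ones(1, 4, 1, 1)                   # an fp32 buffer, as described above

q_scaled = q * scale_mul
print(q_scaled.dtype)        # torch.float32 -- fp16 * fp32 promotes the result to fp32

# Possible workaround: restore the original dtype before the flash-attention call.
q_fixed = q_scaled.to(q.dtype)
print(q_fixed.dtype)         # torch.float16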

The 512x512 model checkpoint?

Hello, this is nice work with a very interesting and intuitive idea.
It seems that only the 256 checkpoints are available; would you be willing to share the 512 checkpoint?

FID misalignment

Great work and thanks for publishing the code!

I encountered some problems when calculating the FID.

Originally I used my own FID calculation code to compute the FID between the images from the ImageNet validation dataset and their VQVAE autoencoding reconstructions, and the result is 0.92, which makes sense. However, when I use the same code to compute the FID between the images from the ImageNet validation dataset and your d16 conditional generated images, the result is 19.13.

I also used your method to calculate the FID: we use the following code to create a npz file and run python evaluator.py VIRTUAL_imagenet256_labeled.npz tmp.npz. The result is 18.25. Do you have any idea where I might be making a mistake?

import os
import torch
from PIL import Image
from utils.misc import create_npz_from_sample_folder

# (assumes `var`, `device`, and `imagenet_val_dataloader` are already set up)
for batch_id, batch in enumerate(imagenet_val_dataloader):
    label = batch["label"]
    gen_images = var.autoregressive_infer_cfg(
                B=label.shape[0],
                label_B=label.to(device),
                cfg=1.5,
                top_k=900,
                top_p=0.96,
                more_smooth=False
            )
    gen_images = gen_images.mul(255).add(0.5).clamp(0, 255).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy()
    for i in range(gen_images.shape[0]):
        Image.fromarray(gen_images[i]).save(os.path.join("./tmp/", str(batch_id) + "_" + str(i) + ".png"))

create_npz_from_sample_folder("./tmp")
