
vector-quantize-pytorch's Introduction

Vector Quantization - Pytorch

A vector quantization library originally transcribed from DeepMind's TensorFlow implementation, made conveniently into a package. It uses exponential moving averages to update the dictionary.

VQ has been successfully used by DeepMind and OpenAI for high quality generation of images (VQ-VAE-2) and music (Jukebox).

Install

$ pip install vector-quantize-pytorch

Usage

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,     # codebook size
    decay = 0.8,             # the exponential moving average decay, lower means the dictionary will change faster
    commitment_weight = 1.   # the weight on the commitment loss
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)

Residual VQ

This paper proposes to use multiple vector quantizers to recursively quantize the residuals of the waveform. You can use this with the ResidualVQ class and one extra initialization parameter.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 8,      # specify number of quantizers
    codebook_size = 1024,    # codebook size
)

x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)
print(quantized.shape, indices.shape, commit_loss.shape)
# (1, 1024, 256), (1, 1024, 8), (1, 8)

# if you need all the codes across the quantization layers, just pass return_all_codes = True

quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)

# (8, 1, 1024, 256)
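For intuition, here is a minimal toy sketch of residual quantization written directly in PyTorch (illustrative only, not this library's internals): each stage quantizes whatever the previous stages left unexplained, and the final output is the sum of the per-stage codes.

import torch

def nearest_codes(x, codebook):
    # x: (n, d), codebook: (c, d) -> per-vector nearest codebook entries, (n, d)
    indices = torch.cdist(x, codebook).argmin(dim = -1)
    return codebook[indices]

torch.manual_seed(0)

x = torch.randn(1024, 256)
codebooks = [torch.randn(512, 256) for _ in range(4)]   # 4 toy (untrained) quantizers

residual = x
reconstruction = torch.zeros_like(x)

for codebook in codebooks:
    codes = nearest_codes(residual, codebook)
    reconstruction = reconstruction + codes   # the final output is the sum of per-stage codes
    residual = residual - codes               # the next stage quantizes what is left over

# with trained codebooks, each stage further shrinks the residual,
# so the summed reconstruction approaches x
print(reconstruction.shape)   # torch.Size([1024, 256])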

Furthermore, this paper uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.

They make two modifications. The first is to share the codebook across all quantizers. The second is to stochastically sample the codes rather than always taking the closest match. You can use both of these features with two extra keyword arguments.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    num_quantizers = 8,
    codebook_size = 1024,
    stochastic_sample_codes = True,
    sample_codebook_temp = 0.1,         # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
    shared_codebook = True              # whether to share the codebooks for all quantizers or not
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)

# (1, 1024, 256), (1, 1024, 8), (1, 8)

A recent paper further proposes to do residual VQ on groups of the feature dimension, showing results equivalent to Encodec while using far fewer codebooks. You can use it by importing GroupedResidualVQ.

import torch
from vector_quantize_pytorch import GroupedResidualVQ

residual_vq = GroupedResidualVQ(
    dim = 256,
    num_quantizers = 8,      # specify number of quantizers
    groups = 2,
    codebook_size = 1024,    # codebook size
)

x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)

# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)

Initialization

The SoundStream paper proposes that the codebook should be initialized by the kmeans centroids of the first batch. You can easily turn on this feature with one flag, kmeans_init = True, for either the VectorQuantize or ResidualVQ class.

import torch
from vector_quantize_pytorch import ResidualVQ

residual_vq = ResidualVQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 4,
    kmeans_init = True,   # set to True
    kmeans_iters = 10     # number of kmeans iterations to calculate the centroids for the codebook on init
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)

# (1, 1024, 256), (1, 1024, 4), (1, 4)

Increasing codebook usage

This repository contains a few techniques from various papers to combat "dead" codebook entries, which is a common problem when using vector quantizers.

Lower codebook dimension

The Improved VQGAN paper proposes to keep the codebook in a lower dimension. The encoded values are projected down to the codebook dimension before quantization and projected back up to the original dimension afterwards. You can set this with the codebook_dim hyperparameter.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    codebook_dim = 16      # paper proposes setting this to 32 or as low as 8 to increase codebook usage
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

# (1, 1024, 256), (1, 1024), (1,)

Cosine similarity

The Improved VQGAN paper also proposes to l2-normalize the codes and the encoded vectors, which boils down to using cosine similarity for the distance. They claim constraining the vectors to a sphere leads to improvements in code usage and downstream reconstruction. You can turn this on by setting use_cosine_sim = True.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    use_cosine_sim = True   # set this to True
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

# (1, 1024, 256), (1, 1024), (1,)
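The equivalence follows from a simple identity: for unit vectors a and b, ||a - b||^2 = 2 - 2 (a · b), so picking the nearest code by Euclidean distance on the sphere is the same as picking the most cosine-similar code. A quick numerical check of the identity (illustrative only):

import torch
import torch.nn.functional as F

a = F.normalize(torch.randn(256), dim = -1)
b = F.normalize(torch.randn(256), dim = -1)

squared_l2 = (a - b).pow(2).sum()
cosine = (a * b).sum()

assert torch.allclose(squared_l2, 2 - 2 * cosine, atol = 1e-5)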

Expiring stale codes

Finally, the SoundStream paper has a scheme where they replace codes whose hits fall below a certain threshold with a randomly selected vector from the current batch. You can set this threshold with the threshold_ema_dead_code keyword.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    threshold_ema_dead_code = 2  # should actively replace any codes that have an exponential moving average cluster size less than 2
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)

# (1, 1024, 256), (1, 1024), (1,)
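Conceptually, the expiry scheme keeps an exponential moving average of how often each code is hit and re-seeds codes whose usage falls below the threshold with vectors from the current batch. A toy sketch of that bookkeeping (illustrative only, not this library's internals):

import torch

codebook_size, dim, threshold, decay = 512, 256, 2.0, 0.8

codebook = torch.randn(codebook_size, dim)
ema_cluster_size = torch.zeros(codebook_size)

batch = torch.randn(1024, dim)
indices = torch.cdist(batch, codebook).argmin(dim = -1)

# EMA update of per-code usage counts
counts = torch.bincount(indices, minlength = codebook_size).float()
ema_cluster_size = decay * ema_cluster_size + (1 - decay) * counts

# expire codes whose EMA usage fell below the threshold, replacing them with random batch vectors
dead = ema_cluster_size < threshold
codebook[dead] = batch[torch.randint(0, batch.shape[0], (int(dead.sum()),))]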

Orthogonal regularization loss

VQ-VAE / VQ-GAN is quickly gaining popularity. A recent paper proposes that when using vector quantization on images, enforcing the codebook to be orthogonal leads to translation equivariance of the discretized codes, leading to large improvements in downstream text to image generation tasks.

You can use this feature by simply setting the orthogonal_reg_weight to be greater than 0, in which case the orthogonal regularization will be added to the auxiliary loss output by the module.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    accept_image_fmap = True,                   # set this true to be able to pass in an image feature map
    orthogonal_reg_weight = 10,                 # in paper, they recommended a value of 10
    orthogonal_reg_max_codes = 128,             # this would randomly sample from the codebook for the orthogonal regularization loss, for limiting memory usage
    orthogonal_reg_active_codes_only = False    # set this to True if you have a very large codebook, and would only like to enforce the loss on the activated codes per batch
)

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)

# loss now contains the orthogonal regularization loss with the weight as assigned
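For reference, the regularization is conceptually a penalty on how far the codebook's (normalized) gram matrix is from the identity. A toy sketch of one common form of this penalty (illustrative only, not necessarily the exact formulation used internally):

import torch
import torch.nn.functional as F

codes = torch.randn(128, 256)          # (num_codes, dim), e.g. a sampled subset of the codebook
normed = F.normalize(codes, dim = -1)

gram = normed @ normed.t()             # (num_codes, num_codes) pairwise cosine similarities
identity = torch.eye(codes.shape[0])

ortho_loss = (gram - identity).pow(2).sum() / codes.shape[0] ** 2
print(ortho_loss)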

Multi-headed VQ

There have been a number of papers that propose variants of discrete latent representations with a multi-headed approach (multiple codes per feature). I have decided to offer one variant where the same codebook is used to vector quantize across the input dimension head times.

You can also use a more proven approach (memcodes) from the NWT paper.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_dim = 32,                  # a number of papers have shown smaller codebook dimension to be acceptable
    heads = 8,                          # number of heads to vector quantize over
    separate_codebook_per_head = True,  # whether to have a separate codebook per head. False would mean 1 codebook shared across all heads
    codebook_size = 8196,
    accept_image_fmap = True
)

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap)

# (1, 256, 32, 32), (1, 32, 32, 8), (1,)

Random Projection Quantizer

This paper first proposed to use a random projection quantizer for masked speech modeling, where signals are projected with a randomly initialized matrix and then matched against a randomly initialized codebook. One therefore does not need to learn the quantizer. This technique was used by Google's Universal Speech Model to achieve SOTA for speech-to-text modeling.

USM further proposes to use multiple codebooks, with the masked speech modeling trained under a multi-softmax objective. You can do this easily by setting num_codebooks to be greater than 1.

import torch
from vector_quantize_pytorch import RandomProjectionQuantizer

quantizer = RandomProjectionQuantizer(
    dim = 512,               # input dimensions
    num_codebooks = 16,      # in USM, they used up to 16 for 5% gain
    codebook_dim = 256,      # codebook dimension
    codebook_size = 1024     # codebook size
)

x = torch.randn(1, 1024, 512)
indices = quantizer(x)

# (1, 1024, 16)
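Conceptually, the random projection quantizer is just a frozen random projection followed by a nearest-neighbor lookup in a frozen random codebook; nothing is learned. A minimal toy sketch of the idea for a single codebook (illustrative only, not this library's implementation):

import torch

torch.manual_seed(0)

dim, codebook_dim, codebook_size = 512, 256, 1024

projection = torch.randn(dim, codebook_dim)            # fixed at init, never trained
codebook = torch.randn(codebook_size, codebook_dim)    # fixed at init, never trained

x = torch.randn(1, 1024, dim)
projected = x @ projection                             # (1, 1024, codebook_dim)

dist = torch.cdist(projected, codebook.unsqueeze(0))   # (1, 1024, codebook_size)
indices = dist.argmin(dim = -1)                        # (1, 1024) discrete targets for masked prediction

print(indices.shape)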

This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting sync_codebook = True | False.
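As a minimal sketch of the override mentioned above, assuming sync_codebook is accepted as a constructor keyword (shown here on VectorQuantize):

from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    sync_codebook = False   # assumed keyword per the note above; force-disables cross-process codebook syncing
)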

Finite Scalar Quantization

                 | VQ                                                    | FSQ
Quantization     | argmin_c ||z - c||                                    | round(f(z))
Gradients        | Straight Through Estimation (STE)                     | STE
Auxiliary Losses | Commitment, codebook, entropy loss, ...               | N/A
Tricks           | EMA on codebook, codebook splitting, projections, ... | N/A
Parameters       | Codebook                                              | N/A

This work out of Google DeepMind aims to vastly simplify the way vector quantization is done for generative modeling, removing the need for commitment losses and EMA updating of the codebook, as well as tackling the issues of codebook collapse and insufficient utilization. They simply round each scalar into discrete levels with straight-through gradients; the codes become uniform points in a hypercube.

Thanks goes out to @sekstini for porting over this implementation in record time!

import torch
from vector_quantize_pytorch import FSQ

quantizer = FSQ(
    levels = [8, 5, 5, 5]
)

x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
xhat, indices = quantizer(x)

# (1, 1024, 4), (1, 1024)

assert torch.all(xhat == quantizer.indices_to_codes(indices))
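For intuition, the core of FSQ is just a per-dimension bound-and-round with a straight-through gradient. A toy sketch of the idea (illustrative only, omitting the library's exact bounding, offsets, and index packing):

import torch

def toy_fsq(z, levels):
    # z: (..., d) with d == len(levels); each scalar is squashed, then rounded onto its level grid
    levels = torch.tensor(levels, dtype = torch.float)
    half = (levels - 1) / 2
    bounded = torch.tanh(z) * half
    quantized = torch.round(bounded)
    # straight-through estimator: forward uses the rounded value, backward ignores the rounding
    return bounded + (quantized - bounded).detach()

z = torch.randn(1, 1024, 4, requires_grad = True)
zhat = toy_fsq(z, [8, 5, 5, 5])

zhat.sum().backward()
print(z.grad is not None)   # gradients flow despite the rounding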

An improvised Residual FSQ, as an attempt to improve audio encoding.

Credit goes to @sekstini for originally incepting the idea here

import torch
from vector_quantize_pytorch import ResidualFSQ

residual_fsq = ResidualFSQ(
    dim = 256,
    levels = [8, 5, 5, 3],
    num_quantizers = 8
)

x = torch.randn(1, 1024, 256)

residual_fsq.eval()

quantized, indices = residual_fsq(x)

# (1, 1024, 256), (1, 1024, 8)

quantized_out = residual_fsq.get_output_from_indices(indices)

# (1, 1024, 256)

assert torch.all(quantized == quantized_out)

Lookup Free Quantization

The research team behind MagViT has released new SOTA results for generative video modeling. A core change between v1 and v2 is a new type of quantization, lookup-free quantization (LFQ), which eliminates the codebook and embedding lookup entirely.

This paper presents a simple LFQ quantizer using independent binary latents. Other implementations of LFQ exist. However, the team shows that MAGVIT-v2 with LFQ significantly improves on the ImageNet benchmark. The differences between LFQ and 2-level FSQ include entropy regularization as well as a maintained commitment loss.

Developing a more advanced method of LFQ quantization without codebook-lookup could revolutionize generative modeling.
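For intuition, the binary-latent core of LFQ can be sketched in a few lines: each dimension is quantized to ±1 by its sign (with a straight-through gradient), and the resulting bit pattern is the integer code. This is only an illustration of the idea, not this library's implementation, which also carries the entropy and commitment terms described above.

import torch

def toy_lfq(z):
    # z: (..., d) -> codes in {-1, 1}^d plus integer indices in [0, 2^d)
    ones = torch.ones_like(z)
    quantized = torch.where(z > 0, ones, -ones)
    quantized = z + (quantized - z).detach()       # straight-through gradient

    bits = (quantized > 0).long()                  # (..., d) binary pattern
    powers = 2 ** torch.arange(z.shape[-1])
    indices = (bits * powers).sum(dim = -1)        # read the bit pattern as an integer code
    return quantized, indices

z = torch.randn(1, 32, 16, requires_grad = True)
quantized, indices = toy_lfq(z)

print(quantized.shape, indices.shape)   # (1, 32, 16), (1, 32)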

You can use it simply as follows. It will be dogfooded in the MagViT2 pytorch port.

import torch
from vector_quantize_pytorch import LFQ

# you can specify either dim or codebook_size
# if both specified, will be validated against each other

quantizer = LFQ(
    codebook_size = 65536,      # codebook size, must be a power of 2
    dim = 16,                   # this is the input feature dimension, defaults to log2(codebook_size) if not defined
    entropy_loss_weight = 0.1,  # how much weight to place on entropy loss
    diversity_gamma = 1.        # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
)

image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature=100.)  # you may want to experiment with temperature

# (1, 16, 32, 32), (1, 32, 32), ()

assert (quantized == quantizer.indices_to_codes(indices)).all()

You can also pass in video features as (batch, feat, time, height, width) or sequences as (batch, seq, feat)

import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    codebook_size = 65536,
    dim = 16,
    entropy_loss_weight = 0.1,
    diversity_gamma = 1.
)

seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)

assert seq.shape == quantized.shape

video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)

assert video_feats.shape == quantized.shape

Or support multiple codebooks

import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    codebook_size = 4096,
    dim = 16,
    num_codebooks = 4  # 4 codebooks, total codebook dimension is log2(4096) * 4
)

image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats)

# (1, 16, 32, 32), (1, 32, 32, 4), ()

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()

An improvised Residual LFQ, to see if it can lead to an improvement for audio compression.

import torch
from vector_quantize_pytorch import ResidualLFQ

residual_lfq = ResidualLFQ(
    dim = 256,
    codebook_size = 256,
    num_quantizers = 8
)

x = torch.randn(1, 1024, 256)

residual_lfq.eval()

quantized, indices, commit_loss = residual_lfq(x)

# (1, 1024, 256), (1, 1024, 8), (8)

quantized_out = residual_lfq.get_output_from_indices(indices)

# (1, 1024, 256)

assert torch.all(quantized == quantized_out)

Latent Quantization

Disentanglement is essential for representation learning, as it promotes interpretability, generalization, improved learning, and robustness. It aligns with the goal of capturing meaningful, independent features of the data, enabling more effective use of learned representations across applications. The challenge is to disentangle the underlying variations in a dataset without explicit ground-truth information. This work introduces a key inductive bias: encoding and decoding within an organized latent space, discretized by assigning each dimension its own learnable scalar codebook. This approach allows their models to outperform strong prior methods.

Be aware they had to use a very high weight decay for the results in this paper.

import torch
from vector_quantize_pytorch import LatentQuantize

# you can specify either dim or codebook_size
# if both specified, will be validated against each other

quantizer = LatentQuantize(
    levels = [5, 5, 8],      # number of levels per codebook dimension
    dim = 16,                   # input dim
    commitment_loss_weight=0.1,  
    quantization_loss_weight=0.1,
)

image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, loss = quantizer(image_feats)

# (1, 16, 32, 32), (1, 32, 32), ()

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()

You can also pass in video features as (batch, feat, time, height, width) or sequences as (batch, seq, feat)

import torch
from vector_quantize_pytorch import LatentQuantize

quantizer = LatentQuantize(
    levels = [5, 5, 8],
    dim = 16,
    commitment_loss_weight=0.1,  
    quantization_loss_weight=0.1,
)

seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)

# (1, 32, 16)

video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)

# (1, 16, 10, 32, 32)

Or support multiple codebooks

import torch
from vector_quantize_pytorch import LatentQuantize

dim = 9
num_codebooks = 3

model = LatentQuantize(
    levels = [4, 8, 16],
    dim = dim,
    num_codebooks = num_codebooks
)

input_tensor = torch.randn(2, 3, dim)
output_tensor, indices, loss = model(input_tensor)

# (2, 3, 9), (2, 3, 3), ()

assert output_tensor.shape == input_tensor.shape
assert indices.shape == (2, 3, num_codebooks)
assert loss.item() >= 0

Citations

@misc{oord2018neural,
    title   = {Neural Discrete Representation Learning},
    author  = {Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
    year    = {2018},
    eprint  = {1711.00937},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{zeghidour2021soundstream,
    title   = {SoundStream: An End-to-End Neural Audio Codec},
    author  = {Neil Zeghidour and Alejandro Luebs and Ahmed Omran and Jan Skoglund and Marco Tagliasacchi},
    year    = {2021},
    eprint  = {2107.03312},
    archivePrefix = {arXiv},
    primaryClass = {cs.SD}
}
@inproceedings{anonymous2022vectorquantized,
    title   = {Vector-quantized Image Modeling with Improved {VQGAN}},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=pfNyExj7z2},
    note    = {under review}
}
@inproceedings{lee2022autoregressive,
    title={Autoregressive Image Generation using Residual Quantization},
    author={Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
    booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
    pages={11523--11532},
    year={2022}
}
@article{Defossez2022HighFN,
    title   = {High Fidelity Neural Audio Compression},
    author  = {Alexandre Défossez and Jade Copet and Gabriel Synnaeve and Yossi Adi},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2210.13438}
}
@inproceedings{Chiu2022SelfsupervisedLW,
    title   = {Self-supervised Learning with Random-projection Quantizer for Speech Recognition},
    author  = {Chung-Cheng Chiu and James Qin and Yu Zhang and Jiahui Yu and Yonghui Wu},
    booktitle = {International Conference on Machine Learning},
    year    = {2022}
}
@inproceedings{Zhang2023GoogleUS,
    title   = {Google USM: Scaling Automatic Speech Recognition Beyond 100 Languages},
    author  = {Yu Zhang and Wei Han and James Qin and Yongqiang Wang and Ankur Bapna and Zhehuai Chen and Nanxin Chen and Bo Li and Vera Axelrod and Gary Wang and Zhong Meng and Ke Hu and Andrew Rosenberg and Rohit Prabhavalkar and Daniel S. Park and Parisa Haghani and Jason Riesa and Ginger Perng and Hagen Soltau and Trevor Strohman and Bhuvana Ramabhadran and Tara N. Sainath and Pedro J. Moreno and Chung-Cheng Chiu and Johan Schalkwyk and Françoise Beaufays and Yonghui Wu},
    year    = {2023}
}
@inproceedings{Shen2023NaturalSpeech2L,
    title   = {NaturalSpeech 2: Latent Diffusion Models are Natural and Zero-Shot Speech and Singing Synthesizers},
    author  = {Kai Shen and Zeqian Ju and Xu Tan and Yanqing Liu and Yichong Leng and Lei He and Tao Qin and Sheng Zhao and Jiang Bian},
    year    = {2023}
}
@inproceedings{Yang2023HiFiCodecGV,
    title   = {HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec},
    author  = {Dongchao Yang and Songxiang Liu and Rongjie Huang and Jinchuan Tian and Chao Weng and Yuexian Zou},
    year    = {2023}
}
@article{Liu2023BridgingDA,
    title   = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
    author  = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2304.08612}
}
@inproceedings{huh2023improvedvqste,
    title   = {Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks},
    author  = {Huh, Minyoung and Cheung, Brian and Agrawal, Pulkit and Isola, Phillip},
    booktitle = {International Conference on Machine Learning},
    year    = {2023},
    organization = {PMLR}
}
@inproceedings{rogozhnikov2022einops,
    title   = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
    author  = {Alex Rogozhnikov},
    booktitle = {International Conference on Learning Representations},
    year    = {2022},
    url     = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{shin2021translationequivariant,
    title   = {Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation},
    author  = {Woncheol Shin and Gyubok Lee and Jiyoung Lee and Joonseok Lee and Edward Choi},
    year    = {2021},
    eprint  = {2112.00384},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{mentzer2023finite,
    title   = {Finite Scalar Quantization: VQ-VAE Made Simple},
    author  = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
    year    = {2023},
    eprint  = {2309.15505},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yu2023language,
    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
    year    = {2023},
    eprint  = {2310.05737},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{Zhao2024ImageAV,
  title     = {Image and Video Tokenization with Binary Spherical Quantization},
  author    = {Yue Zhao and Yuanjun Xiong and Philipp Krahenbuhl},
  year      = {2024},
  url       = {https://api.semanticscholar.org/CorpusID:270380237}
}
@misc{hsu2023disentanglement,
    title   = {Disentanglement via Latent Quantization}, 
    author  = {Kyle Hsu and Will Dorrell and James C. R. Whittington and Jiajun Wu and Chelsea Finn},
    year    = {2023},
    eprint  = {2305.18378},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{Irie2023SelfOrganisingND,
    title   = {Self-Organising Neural Discrete Representation Learning \`a la Kohonen},
    author  = {Kazuki Irie and Róbert Csordás and J{\"u}rgen Schmidhuber},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:256901024}
}
@article{Huijben2024ResidualQW,
    title   = {Residual Quantization with Implicit Neural Codebooks},
    author  = {Iris Huijben and Matthijs Douze and Matthew Muckley and Ruud van Sloun and Jakob Verbeek},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2401.14732},
    url     = {https://api.semanticscholar.org/CorpusID:267301189}
}

vector-quantize-pytorch's People

Contributors

amirhm, dwromero, falkaer, hbenazha, kashif, kifarid, leedoyup, leng-yue, lijun-yu, lucidrains, matwilso, misterbourbaki, npuichigo, sekstini, theadamcolton, wesbz


vector-quantize-pytorch's Issues

Training when the input to be quantized contains padding

When we quantize a batch of sequences, with each sequence padded to the longest length in the batch, I would assume that while calculating and ordering the Euclidean distances for each embedding in a sequence, we would have to ignore the embeddings corresponding to the pad indices when training/updating the codebook. This is not explicitly handled in the code. Can this have a negative effect, or is it safe to ignore?

Error when using gloo as DDP backend

Hello! Thank you for your great work implementing the VQ layer. When I use the VQ layer in DDP mode with gloo as the backend, as suggested in the README, I get the following error:
terminate called after throwing an instance of 'gloo::EnforceNotMet' what(): [enforce fail at ../third_party/gloo/gloo/transport/tcp/pair.cc:510] op.preamble.length <= op.nbytes. 8773632 vs 8386560

Do you have any ideas on how to solve this problem?
I also tried to use nccl as the backend, however the program only hangs forever...

Loss and Backprop Details

Hi,

During training, the VQ-VAE backprops on multiple losses. When inputting feature maps to the model, we are given a loss. Should I manually backpropagate and update weights through this (the good ol' loss.backward() and optimizer.step()), or is it handled implicitly?

Updating & Commitment Loss

Hi, I have some questions about recent commits.

  1. A recent commit was changed to get the quantized vectors after updating 'self.embed'. I then wonder whether we should recompute the distances after updating.
    I mean, we use 'embed_ind' to get the quantized vectors, and calculate 'dist' to get 'embed_ind'. However, the updated codebook might give a different 'dist' and 'embed_ind'. What do you think about calculating 'dist' again after updating the codebook and before getting the quantized vectors?

  2. Also, this commit uses 'l2norm_x' to get the quantized vectors using cosine similarity. I think this was intended so that the original input 'x' remains in its original domain, with the gradient backpropagated through 'l2norm_x'. But the commitment loss is still calculated using the original x and the l2-normed quantized vectors, which pushes the input x towards the l2-normed hypersphere. Should we also calculate the commitment loss with 'l2norm_x'?

Thanks in advance.

Crash on Mac M1/M2 chip when using MPS support

My environment:
Mac OS 13.2.1 with M2 Pro
Python 3.9.16
pytorch: 2.0.1
vector_quantize_pytorch: 1.6.11

site-packages/vector_quantize_pytorch/vector_quantize_pytorch.py:444: UserWarning: The operator 'aten::lerp.Scalar_out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at /Users/runner/work/temp/anaconda/conda-bld/pytorch_1682343673238/work/aten/src/ATen/mps/MPSFallback.mm:11.)
self.cluster_size.data.lerp_(cluster_size, 1 - self.decay)
libc++abi: terminating with uncaught exception of type c10::TypeError: Trying to convert ComplexFloat to the MPS backend but it does not have support for that dtype.
Exception raised from getMPSScalarType at /Users/runner/work/_temp/anaconda/conda-bld/pytorch_1682343673238/work/aten/src/ATen/native/mps/OperationUtils.mm:91 (most recent call first):
frame #0: at::native::mps::getMPSScalarType(c10::ScalarType) + 180 (0x1369364ac in libtorch_cpu.dylib)
frame #1: at::native::mps::mpsGraphRankedPlaceHolder(MPSGraph*, at::Tensor const&) + 96 (0x136938654 in libtorch_cpu.dylib)
frame #2: invocation function for block in at::native::mps::unary_op(at::Tensor const&, at::Tensor const&, std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator >, MPSGraphTensor* (MPSGraph*, MPSGraphTensor*) block_pointer, std::__1::function<bool (at::Tensor const&)>) + 104 (0x136a00504 in libtorch_cpu.dylib)
frame #3: invocation function for block in at::native::mps::MPSGraphCache::CreateCachedGraph(std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator > const&, at::native::mps::MPSCachedGraph* () block_pointer) + 216 (0x13694c58c in libtorch_cpu.dylib)
frame #4: _dispatch_client_callout + 20 (0x199392504 in libdispatch.dylib)
frame #5: _dispatch_lane_barrier_sync_invoke_and_complete + 56 (0x1993a1a9c in libdispatch.dylib)
frame #6: at::native::mps::MPSGraphCache::CreateCachedGraph(std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator > const&, at::native::mps::MPSCachedGraph* () block_pointer) + 160 (0x13693a5d0 in libtorch_cpu.dylib)
frame #7: at::native::mps::unary_op(at::Tensor const&, at::Tensor const&, std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator >, MPSGraphTensor* (MPSGraph*, MPSGraphTensor*) block_pointer, std::_1::function<bool (at::Tensor const&)>) + 860 (0x1369fff9c in libtorch_cpu.dylib)
frame #8: at::native::abs_out_mps(at::Tensor const&, at::Tensor&) + 124 (0x136a02bd0 in libtorch_cpu.dylib)
frame #9: at::ops::abs_out::call(at::Tensor const&, at::Tensor&) + 276 (0x1330e54a8 in libtorch_cpu.dylib)
frame #10: at::native::abs(at::Tensor const&) + 232 (0x132adba88 in libtorch_cpu.dylib)
frame #11: c10::impl::wrap_kernel_functor_unboxed
<c10::impl::detail::WrapFunctionIntoFunctor
<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&), &(torch::autograd::VariableType::(anonymous namespace)::abs(c10::DispatchKeySet, at::Tensor const&))>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, at::Tensor (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) + 920 (0x134e13950 in libtorch_cpu.dylib)
frame #12: at::_ops::abs::call(at::Tensor const&) + 264 (0x1330e4900 in libtorch_cpu.dylib)
frame #13: torch::autograd::THPVariable_abs(_object*, _object*) + 188 (0x104b006d4 in libtorch_python.dylib)
frame #14: method_vectorcall_NOARGS + 172 (0x102a964e4 in python3.9)
frame #15: call_function + 516 (0x102b793cc in python3.9)
frame #16: _PyEval_EvalFrameDefault + 26296 (0x102b75ba8 in python3.9)
frame #17: _PyEval_EvalCode + 2804 (0x102b6ee98 in python3.9)
frame #18: _PyFunction_Vectorcall + 220 (0x102a89fe4 in python3.9)
frame #19: call_function + 516 (0x102b793cc in python3.9)
frame #20: _PyEval_EvalFrameDefault + 26332 (0x102b75bcc in python3.9)
frame #21: function_code_fastcall + 116 (0x102a8a0a0 in python3.9)
frame #22: method_vectorcall + 516 (0x102a8cf24 in python3.9)
frame #23: _PyEval_EvalFrameDefault + 27140 (0x102b75ef4 in python3.9)
frame #24: _PyEval_EvalCode + 2804 (0x102b6ee98 in python3.9)
frame #25: _PyFunction_Vectorcall + 220 (0x102a89fe4 in python3.9)
frame #26: _PyObject_FastCallDictTstate + 320 (0x102a89674 in python3.9)
frame #27: _PyObject_Call_Prepend + 164 (0x102a8a464 in python3.9)
frame #28: slot_tp_call + 116 (0x102afa5b0 in python3.9)
frame #29: _PyObject_MakeTpCall + 616 (0x102a893c4 in python3.9)
frame #30: call_function + 668 (0x102b79464 in python3.9)
frame #31: _PyEval_EvalFrameDefault + 26332 (0x102b75bcc in python3.9)
frame #32: function_code_fastcall + 116 (0x102a8a0a0 in python3.9)
frame #33: method_vectorcall + 516 (0x102a8cf24 in python3.9)
frame #34: _PyEval_EvalFrameDefault + 27140 (0x102b75ef4 in python3.9)
frame #35: _PyEval_EvalCode + 2804 (0x102b6ee98 in python3.9)
frame #36: _PyFunction_Vectorcall + 220 (0x102a89fe4 in python3.9)
frame #37: _PyObject_FastCallDictTstate + 320 (0x102a89674 in python3.9)
frame #38: _PyObject_Call_Prepend + 164 (0x102a8a464 in python3.9)
frame #39: slot_tp_call + 116 (0x102afa5b0 in python3.9)
frame #40: _PyObject_MakeTpCall + 616 (0x102a893c4 in python3.9)
frame #41: call_function + 668 (0x102b79464 in python3.9)
frame #42: _PyEval_EvalFrameDefault + 26456 (0x102b75c48 in python3.9)
frame #43: _PyEval_EvalCode + 2804 (0x102b6ee98 in python3.9)
frame #44: _PyFunction_Vectorcall + 220 (0x102a89fe4 in python3.9)
frame #45: call_function + 516 (0x102b793cc in python3.9)
frame #46: _PyEval_EvalFrameDefault + 26456 (0x102b75c48 in python3.9)
frame #47: _PyEval_EvalCode + 2804 (0x102b6ee98 in python3.9)
frame #48: _PyFunction_Vectorcall + 220 (0x102a89fe4 in python3.9)
frame #49: call_function + 516 (0x102b793cc in python3.9)
frame #50: _PyEval_EvalFrameDefault + 26456 (0x102b75c48 in python3.9)
frame #51: _PyEval_EvalCode + 2804 (0x102b6ee98 in python3.9)
frame #52: _PyFunction_Vectorcall + 220 (0x102a89fe4 in python3.9)
frame #53: call_function + 516 (0x102b793cc in python3.9)
frame #54: _PyEval_EvalFrameDefault + 26456 (0x102b75c48 in python3.9)
frame #55: _PyEval_EvalCode + 2804 (0x102b6ee98 in python3.9)
frame #56: run_mod + 216 (0x102bc8bdc in python3.9)
frame #57: pyrun_file + 264 (0x102bc6728 in python3.9)
frame #58: PyRun_SimpleFileExFlags + 1332 (0x102bc5f50 in python3.9)
frame #59: Py_RunMain + 2148 (0x102be88c0 in python3.9)
frame #60: pymain_main + 1252 (0x102be9a40 in python3.9)
frame #61: main + 56 (0x102a3c770 in python3.9)
frame #62: start + 2544 (0x1991efe50 in dyld)

Bug when using flag `orthogonal_reg_active_codes_only`

Hello, thank you for the great work of VQ-VAE.
While reading your implementation when turning on the flag orthogonal_reg_active_codes_only.

if self.orthogonal_reg_weight > 0:
    codebook = self._codebook.embed

    if self.orthogonal_reg_active_codes_only:
        # only calculate orthogonal loss for the activated codes for this batch
        unique_code_ids = torch.unique(embed_ind)
        codebook = codebook[unique_code_ids]

    num_codes = codebook.shape[0]
    if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
        rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes]
        codebook = codebook[rand_ids]

I think it does not work properly. The codebook shape after the line codebook = self._codebook.embed is [num_codebook, codebook_size, codebook_dim]. Therefore this line codebook = codebook[unique_code_ids] should be codebook = codebook[:, unique_code_ids] and num_codes = codebook.shape[0] should be num_codes = codebook.shape[1]. Am I correct?
Overall this above code should be

                if self.orthogonal_reg_active_codes_only:
                    # only calculate orthogonal loss for the activated codes for this batch
                    unique_code_ids = torch.unique(embed_ind)
                    codebook = codebook[:, unique_code_ids]

                num_codes = codebook.shape[1]
                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
                    rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes]
                    codebook = codebook[:, rand_ids]

And even with the above code, orthogonal_reg_active_codes_only only works properly when num_codebook = 1.

Looking forward to hearing your opinion.

EMA update on CosineCodebook

The original VIT-VQGAN paper does not seem to use EMA update for codebook learning since their codebook is unit-normalized vectors.

Particularly, to my understanding, EMA update does not quite make sense when the encoder outputs and codebook vectors are unit-normalized ones.

What's your take on this? Should we NOT use EMA update with CosineCodebook?

VectorQuantize not JIT-safe

The current code controls the program based on the values of tensors (like the "initted" buffer) and will not work when compiling a jit trace.

Cannot Converge with L2 Loss

I am trying to quantize the latent vector. To be specific, I use an Encoder to get the latent representation z of the input. Then I quantize z and send it into the Decoder.

However, during my experiments, I found the reconstruction loss does not decrease with the L2 loss, namely the EuclideanCodebook. The model can converge with cosine similarity. Do you have any idea about this phenomenon?

I think cosine similarity only considers the direction of the vector rather than its scale. I still want to use the EuclideanCodebook.

Gradient of gumbel straight through multiplier was not correct.

In function gumbel_sample(vector_quantize_pytorch.py line 92).
st_mult = one_hot.gather(-1, rearrange(ind, '... -> ... 1')) # multiplier for straight-through
We gather the one-hot to get st_mult, and then multiply it with the gathered embedding, like this:

onehot = gather(onehot)
emb = gather(emb)
emb = emb * onehot

But I found this is different in the taming-transformers codebase:
emb = torch.einsum('bs,sd->b,s',onehot,emb)

It seems the same, since the outputs of the first and the second snippet are identical.

But the gradient of the first snippet is not correct.
This is because even though most values of the one-hot are zero, those zeros will still receive gradient.
This is just like ControlNet's zero conv, if you are interested in it.

Let's consider this:
x = [x1, x2]
y = [1, 0]
o = x1 * 1 + x2 * 0

The gradient with respect to y2 is ∂o/∂y2 = x2, and this is not zero!
Tested on my own code (screenshots omitted).

Please let me know if I made any mistake. Thanks again for your contributions.

Commitment Loss Problems

Hello,

First of all, thank you so much for this powerful implementation.

I have been researching training VQ-VAEs to generate faces from FFHQ 128x128, and I always have the same problem: if I use the commitment loss (0.25) and the gamma (0.99) as in the original paper, the commitment loss seems to grow indefinitely. I know you said that it is an auxiliary loss and not that important, but is this normal behavior? If not, how can I avoid it in the case where I want to use this loss?

Thank you so much in advance!

RVQ loss

Firstly, thanks for your code. I have a question: when I use RVQ, it returns 8 losses. How do you handle this, do you add all the losses into one?

Quantizers are not DDP/AMP compliant

Hi Lucidrains,

Thanks for the amazing work you do by implementing all those papers!

Is there a plan to make the Quantizer be compliant with:

  • DDP - They need an all gather before calculating anything so the updates are exactly the same across all ranks
  • AMP - In my experience, if AMP touches upon the quantizers it screws up the gradient magnitudes making it NaN/Overflow

If you want I can have a go at it.

Hi, I hit some bugs when running inference with the model

What are the causes of the following problems?

Missing key(s) in state_dict: "quantizer.vq.layers.12._codebook.inited", "quantizer.vq.layers.12._codebook.cluster_size", "quantizer.vq.layers.12._codebook.embed", "quantizer.vq.layers.12._codebook.embed_avg", "quantizer.vq.layers.13._codebook.inited", "quantizer.vq.layers.13._codebook.cluster_size", "quantizer.vq.layers.13._codebook.embed", "quantizer.vq.layers.13._codebook.embed_avg", "quantizer.vq.layers.14._codebook.inited", "quantizer.vq.layers.14._codebook.cluster_size", "quantizer.vq.layers.14._codebook.embed", "quantizer.vq.layers.14._codebook.embed_avg", "quantizer.vq.layers.15._codebook.inited", "quantizer.vq.layers.15._codebook.cluster_size", "quantizer.vq.layers.15._codebook.embed", "quantizer.vq.layers.15._codebook.embed_avg", "quantizer.vq.layers.16._codebook.inited", "quantizer.vq.layers.16._codebook.cluster_size", "quantizer.vq.layers.16._codebook.embed", "quantizer.vq.layers.16._codebook.embed_avg", "quantizer.vq.layers.17._codebook.inited", "quantizer.vq.layers.17._codebook.cluster_size", "quantizer.vq.layers.17._codebook.embed", "quantizer.vq.layers.17._codebook.embed_avg", "quantizer.vq.layers.18._codebook.inited", "quantizer.vq.layers.18._codebook.cluster_size", "quantizer.vq.layers.18._codebook.embed", "quantizer.vq.layers.18._codebook.embed_avg", "quantizer.vq.layers.19._codebook.inited", "quantizer.vq.layers.19._codebook.cluster_size", "quantizer.vq.layers.19._codebook.embed", "quantizer.vq.layers.19._codebook.embed_avg", "quantizer.vq.layers.20._codebook.inited", "quantizer.vq.layers.20._codebook.cluster_size", "quantizer.vq.layers.20._codebook.embed", "quantizer.vq.layers.20._codebook.embed_avg", "quantizer.vq.layers.21._codebook.inited", "quantizer.vq.layers.21._codebook.cluster_size", "quantizer.vq.layers.21._codebook.embed", "quantizer.vq.layers.21._codebook.embed_avg", "quantizer.vq.layers.22._codebook.inited", "quantizer.vq.layers.22._codebook.cluster_size", "quantizer.vq.layers.22._codebook.embed", "quantizer.vq.layers.22._codebook.embed_avg", "quantizer.vq.layers.23._codebook.inited", "quantizer.vq.layers.23._codebook.cluster_size", "quantizer.vq.layers.23._codebook.embed", "quantizer.vq.layers.23._codebook.embed_avg".

Eval mode

Hi,
It seems like there's a problem when using soundstream.eval() due to the RVQ part. Not a big deal but I wanted to let you know about that

Missing feature to reproduce SoundStream's Residual Vector Quantizer

Hi,
Thanks for this cool work!
I couldn't help but notice that a few features used to improve codebook usage are missing for this to be an exact implementation of the work done in the SoundStream article.

  • "First, instead of using a random initialization for the codebook vectors, we run the k-means algorithm on the first training batch and use the learned centroids as initialization"
  • "Second, as proposed in [34], when a codebook vector has not been assigned any input frame for several batches, we replace it with an input frame randomly sampled within the current batch."
    I'm currently working on an implementation of this work, I'll use and adapt your code for this purpose but thought you might want to know about it.
    I'll keep you posted :)

quantize_dropout is not compatible with accept_image_fmap

line 98 of residual_vq.py is as follows,

null_indices = torch.full((b, n), -1., device = device, dtype = torch.long)

I believe it should be something like this, since the (b, n) shape is not applicable when accept_image_fmap=True?

null_indices = torch.full(all_indices[0].shape, -1., device = device, dtype = torch.long)

[Feature Request] Reservoir Restart & Batch Normalisation Before Flattening

There are two interesting features (low implementation overhead) from the paper Robust Training of Vector Quantized Bottleneck Models:

  • 3.A Importance of proper scaling - Batch normalisation
    • This is in line with the Orthogonal Regularisation, but I am not sure of its interaction with it. Furthermore it will need DDP handling by using SyncBatchNorm more specifically the SyncBatchNorm.convert_sync_batchnorm() method
  • 3.B Batch data-dependent codebook updates
    • This is an extended logic improvement of the K-mean initialization and of the Code Restarts.
    • Useful code is here

Missing parameter of beta

Hi, in the original VQVAE paper, the commit_loss is defined as

(quantize.detach() - x) ** 2 + beta * (quantize - x.detach()) ** 2

where beta is usually set to 0.25. But the commit_loss is defined as follows in your implementation:

F.mse_loss(quantize.detach(), x)

So I wonder if the parameter beta is set to be 1 by default or if the second term is missing? Thank you very much.

Replace also the embed_avg?

Hey, I'm new in VQ subject so maybe i got it wrong but i have a question about the replacement mechanism in the repo.

When we replace a vector in the codebook with a vector from the batch, shouldn't we also replace the corresponding index in the "embed_avg" (with the same vector) and kind of reset it? It makes no sense to me to keep averaging with the old vector on the next iteration.
The same logic applies to the cluster size: why keep the old cluster size if we have now replaced the vector? Shouldn't we reset it to 1, or maybe even better, let the next iteration determine its starting cluster size (like in the k-means init)? Otherwise, it can "kill" a spot in the codebook if we replace it several times in a row with no good sample.

If I'm completely wrong and there is info about this in the repo or the papers, I would love to know.

Thanks a lot for this repo,
Amitai.

The way self.embed_avg is being computed

In the current implementation the self.embed_avg is being computed as follows (code)
embed_avg = ema(embed_avg, embed_sum) / (laplace_smoothing(self.cluster_size) * self.cluster_size.sum()), where embed_sum is a tensor of shape [n_classes, hidden_size] in which all batch x time samples of the current batch that belong to the same class are summed up. Since we want a mean, not a sum, we divide by self.cluster_size, but since some values might be equal to 0, we introduce the laplace_smoothing, which in turn introduces the division by self.cluster_size.sum(), which we then invalidate by multiplying by it again.

My question is, wouldn't it be better to compute embed_avg as embed_avg = ema(embed_avg, torch.where(self.cluster_size == 0, 0, embed_sum / self.cluster_size.unsqueeze(dim=-1)))? It seems more accurate, since we divide embed_sum, not ema(embed_avg, embed_sum), by cluster_size, and there are no epsilons that appear in both the numerator and denominator.

Maybe I just did not get the intuition behind your code. Could you please tell me why you wrote it the way it is? Is it based on an algorithm from some paper, or have you taken this code from some other repository?

Explain DDP example

Hi can you explain the DDP example? Do you know what is needed for it to work with pytorch lightning & tpu?

Why L2 Normed l2 distance is cosine?

I'm new to VQGAN, but in the improved VQGAN they use L2 distance with L2-normed embeddings.

I am confused: the L2 distance between L2-normed vectors is not equal to cosine similarity. So why do we say it boils down to using cosine similarity for the distance?

orthogonal regularization loss useless?

because the codebooks are not registered as trainable parameters, and the orthogonal loss is only a function of the codebooks, is the orthogonal loss entirely useless?

codebook initialization

Hi, Thank you for this great work. It's quite useful!

I have been having problems with index collapse and I'm not sure where it's coming from. But upon digging into the code, it seems that when we're not using k-means to initialize the codebook vectors, randn (normal distribution) is used to initialize them. The vqvae paper specifically uses uniform distribution for initialization, which allows the authors to ignore KL divergence when training.

This is from the vqvae paper: "Since we assume a uniform prior for z, the KL term that usually appears in the ELBO is constant w.r.t. the encoder parameters and can thus be ignored for training."

Is there any reason why you changed to Normal distribution here?

Thanks!

kmeans and ddp hangs

kmeans and ddp hangs for me. ddp is initialized by pytorch lightning in my case. I have several questions:

In https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py#L98

all_num_samples = all_gather_sizes(local_samples, dim = 0) should it be dim = 1 (as dim 0 is the codebook dimension)?

Then in https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py#L93 it just hangs for me. I am not totally sure, but I believe distributed.broadcast in

https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/vector_quantize_pytorch.py#L90

is called with incompatible shapes. See https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast

tensor must have the same number of elements in all processes participating in the collective.

No way of training the codebook

Hi!
Could you please explain how the codebook vectors are updated if the codebook vectors are not required to be orthogonal?

  1. embed tensors in both Euclidean and CosineSim codebooks are registered as buffers, so they can't be updated at all
  2. There is no loss on the codebook vectors that moves them closer to the input

Am I missing something? It seems that right now there is no way of updating the codebook vectors without the orthogonal loss.

zero for second residual grad

Thanks for your work. When we checked the code, we found that there is no gradient for the residual layers after the second layer; please confirm.

We changed the code from residual = residual - quantized to residual = residual - quantized.detach()


Here's the verification we did

if __name__ == "__main__":
    quantizer = ResidualVQ(
        num_quantizers=4, dim=256, codebook_size=16,
        kmeans_init=True, kmeans_iters=10, threshold_ema_dead_code=2, channel_last=False,
    )

    for i in range(4):
        input = torch.rand((2, 256, 30), requires_grad=True)
        quantized, indices, losses = quantizer(input)
        print(quantized.shape, indices.shape, losses.shape)

        losses[0, i].backward()
        print(input.grad)

Memory leak on 1.6.14

I'm encountering a memory leak after updating to 1.6.14 from 1.6.11.
At first, I noticed that my training code, which was previously working just fine, is now encountering an OOM error after several iterations, which is weird.
I can confirm the memory leak issue comes from this library, because when I downgraded the library back to 1.6.11, the OOM errors were no longer present.

The model follows this structure:

self.encoder = Encoder(...)
self.vq = GroupedResidualVQ(
    dim=dim_emb,
    num_quantizers=num_codebooks,
    codebook_size=n_emb,
    groups=vq_groups,
    decay=vq_decay,
    commitment_weight=vq_commitment_weight,
    quantize_dropout_multiple_of=1,
    quantize_dropout=True,
    quantize_dropout_cutoff_index=vq_quantize_dropout_cutoff_index,
    kmeans_init=True,
    threshold_ema_dead_code=2,
    stochastic_sample_codes=vq_stochastic_sample_codes,
)
self.decoder = Decoder(...)

And the snippet of the training code is like this:

x = model.encoder(x)
quantized, _, commit_loss = model.vq(x)
recon_x = model.decoder(quantized)

For context, I trained this on a Google Colab environment.

Vit-VQGAN

Hello. Thanks for the great repo. Is ViT-VQGAN implemented in this repo? It looks like it is far ahead of CNN decoders/encoders in performance/quality.

How to train this?

Hi, I want to use this package to experiment with data different from images (multivariate time series).
I see that the commitment_loss that is returned is not a tensor but rather a built-in float, hence it's not possible to backprop through it.

For now I didn't modify any of my other loss calculation code; I just plugged in the quantizer at the beginning of my architecture, but I'd like to be sure this is the correct way to go about this.

Thanks and keep up, you're doing god's work with your repositories!

It seems that the vector from the GPU cannot be input

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,     # codebook size
    decay = 0.8,             # the exponential moving average decay, lower means the dictionary will change faster
    commitment_weight = 1.   # the weight on the commitment loss
)

x = torch.randn(1, 1024, 256)
x = x.cuda()
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)

File "/vector_quantize_pytorch/vector_quantize_pytorch.py", line 280, in forward
dist = -torch.cdist(flatten, embed, p = 2)
File "/torch/functional.py", line 1153, in cdist
return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined]
RuntimeError: X1 and X2 must have the same device type. X1: cuda X2: cpu
How do I change this so it works on the GPU?
Thanks!

RQ-VAE: How can I get a list of all learned codebook vectors (as indexed in the "indices")?

Hi Lucid,
i am working on quantizing CLIP image embeddings with your RQ-VAE. It works pretty well.

Next I want to take all learned codebook vectors and add them to the vocab of a GPT (as frozen token embeddings).

The idea is to train a GPT with CLIP image embeddings in between texts, e.g. IMAGE-CAPTION or TEXT-IMAGE-TEXT-IMAGE- ... Flamingo-style).

If this works, then GPT could maybe also learn to generate quantized CLIP IM embeddings token by token --> and then e.g. show images through a.) retrieval or b.) a DALLE 2 decoder :)

... So my question is: once the RQ-VAE is trained and I can get the quantized reconstructions and indices, how can I get a list or tensor of the actual codebook (all possible vectors from the rq-vocab)? :)

Codebook Update with Cosine Similarity

In the CosineSimCodebook forward function (line 449), is there any reason why we use 'bins' directly instead of 'self.cluster_size'?
It seems EuclideanCodebook uses the Laplace-smoothed 'self.cluster_size' at line 308, but CosineSimCodebook doesn't.
There is a process to update 'self.cluster_size' at line 441, but it seems it is never used in the CosineSimCodebook class (except for expiring).

Plugging vector-quantize-pytorch into taming-transformers

Hi,

I noticed your architecture could be plugged into the pipeline from https://github.com/CompVis/taming-transformers. I have proposed code here (https://github.com/tanouch/taming-transformers) that does this. It enables one to properly compare the different features proposed in your repo (lower codebook dimension, cosine similarity, orthogonal regularization loss, etc.) with the original formulation.

The code from this repo can be seen in both files

  • taming-transformers/taming/models/vqgan.py
  • taming-transformers/taming/modules/vqvae/quantize.py

As you can see, it is easy to launch a large scale training with your proposed architecture.

I am not sure this issue belongs here or in the taming-transformers repo. However, I thought you might be interested.
Thanks again for your work and these open-sourced repositories!
