VQ-VAE, from [1711.00937] Neural Discrete Representation Learning, is a method for unsupervised learning of discrete representations and is still widely used in multimodal generation. Here are some notes on the code.
VQVAE
The idea behind VQ-VAE itself is very simple. It was proposed together with PixelCNN and is closely tied to autoregressive models: generative models such as VAE and GAN estimate the data as a whole, whereas autoregressive models, being related to sequence models, model the data-generating distribution step by step.
Autoregressive models condition their predictions on prior values in the series, rather than on underlying random variables. Thus, they attempt to explicitly model the data generating distribution rather than approximate it
PixelCNN is an autoregressive model: it generates results sequentially, each step sampling from the discrete codes obtained from VQ-VAE. To achieve this it uses a masked convolution, zeroing out the trailing part of the convolution weights so that the convolution does not look at future positions (see ToyPixelCNN.ipynb at master – pilipolio/learn-pytorch).
```python
class MaskedConv(nn.Conv2d):
    ...
```
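A minimal sketch of how such a masked convolution is typically implemented, assuming the usual PixelCNN 'A'/'B' mask convention (this is not necessarily the notebook's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv(nn.Conv2d):
    """Conv2d whose kernel is masked so each pixel only sees pixels
    above it and to its left (raster-scan order)."""
    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert mask_type in ('A', 'B')  # type 'A' also hides the centre pixel
        kH, kW = self.kernel_size
        mask = torch.ones(1, 1, kH, kW)
        mask[:, :, kH // 2, kW // 2 + (mask_type == 'B'):] = 0  # right of centre
        mask[:, :, kH // 2 + 1:, :] = 0                         # rows below centre
        self.register_buffer('mask', mask)

    def forward(self, x):
        # mask the weights before running the usual convolution
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)
```

In the usual PixelCNN setup the first layer uses mask type 'A' and all later layers use type 'B'.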
Many models, including transformers, are now autoregressive, whereas GAN and VAE are not; their drawback is that discrete data is difficult for them to model. VQ-VAE compensates for this.
The point of VQ-VAE is to design a discrete dictionary (codebook) and then use a trick to route the gradient so that the encoder and the dictionary can be updated.
This trick is called the straight-through estimator, which passes the gradient received by the decoder input directly to the encoder output. Assuming the codebook's shape is [codebook_size, codebook_dim] and the input feature's shape is [size, codebook_dim], their pairwise distance is computed with some metric (torch.cdist can be used) to get a [size, codebook_size] matrix, which is equivalent to finding, for each position in the feature, the corresponding entry in the dictionary.
```python
# Approach 1: pairwise distances via torch.cdist
d = torch.cdist(z_flat, codebook)  # -> [size, codebook_size]
```
The embedded features are then gathered from the codebook according to the nearest distance.
```python
# Approach 1: index the codebook directly with the argmin indices
z_q = codebook[d.argmin(dim=1)]  # [size, codebook_dim]
```
Approach 2 first builds one-hot encodings of shape [encode_size, embed_size] (one row per encoded position, one column per codebook entry) and multiplies them with the codebook weight, as in the matmul snippet below. In the loss below, the second term is used to update the dictionary, while the third term is the commitment loss, which is used to update the encoder output.
To learn the embedding space, one of the simplest dictionary-learning algorithms, vector quantization (VQ), is used. The VQ objective uses the l2 error to move the embedding vectors e_i toward the encoder outputs z_e(x).
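For reference, the full VQ-VAE objective from the paper, where sg denotes the stop-gradient operator and β weights the commitment loss:

$$L = \log p(x \mid z_q(x)) + \lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2 + \beta \lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2$$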
```python
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
```
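Putting the pieces together, here is a minimal sketch of a vector-quantization forward pass with the straight-through estimator and both loss terms. Variable names and the β value are placeholders, loosely in the spirit of the implementations listed under Relevant information:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size, codebook_dim, beta=0.25):
        super().__init__()
        self.embedding = nn.Embedding(codebook_size, codebook_dim)
        self.embedding.weight.data.uniform_(-1 / codebook_size, 1 / codebook_size)
        self.beta = beta

    def forward(self, z):                      # z: [B, H, W, C] with C == codebook_dim
        z_flat = z.reshape(-1, z.shape[-1])    # [size, codebook_dim]
        # distances to every codebook entry -> [size, codebook_size]
        d = torch.cdist(z_flat, self.embedding.weight)
        indices = d.argmin(dim=1)              # nearest code per position
        z_q = self.embedding(indices).view(z.shape)

        # codebook loss (2nd term) and commitment loss (3rd term)
        codebook_loss = F.mse_loss(z_q, z.detach())
        commitment_loss = F.mse_loss(z, z_q.detach())
        loss = codebook_loss + self.beta * commitment_loss

        # straight-through estimator: copy gradients from z_q back to z
        z_q = z + (z_q - z).detach()
        return z_q, loss, indices
```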
Additionally, the dictionary can be updated with an exponential moving average (EMA). The update logic: on each step, ema_cluster_size is updated by EMA with the number of feature vectors assigned to each embedding vector (i.e. how many features have it as their nearest code), and the value of each embedding (the weight) is likewise updated by EMA.
```python
# Update weights with EMA
```
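A sketch of what that EMA update typically looks like inside the quantizer, assuming buffers ema_cluster_size and ema_w, a decay factor and a small eps for Laplace smoothing (these names are assumptions, not the author's exact code):

```python
# given:
#   encodings: one-hot assignments, [size, codebook_size]
#   z_flat:    encoder outputs,     [size, codebook_dim]
with torch.no_grad():
    # EMA of how many features fall into each code
    self.ema_cluster_size.mul_(self.decay).add_(
        encodings.sum(0), alpha=1 - self.decay)

    # Laplace smoothing to avoid empty clusters / division by zero
    n = self.ema_cluster_size.sum()
    cluster_size = ((self.ema_cluster_size + self.eps)
                    / (n + self.codebook_size * self.eps) * n)

    # EMA of the sum of features assigned to each code
    dw = encodings.t() @ z_flat                       # [codebook_size, codebook_dim]
    self.ema_w.mul_(self.decay).add_(dw, alpha=1 - self.decay)

    # update weights with EMA: normalised running average
    self.embedding.weight.copy_(self.ema_w / cluster_size.unsqueeze(1))
```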
VQVAE-2
It is simply a multi-scale VQ-VAE, designed with multiple encoder-quantizer-decoder stages.
First, the input is downsampled by multiple encoders to obtain features at different scales; these features are then quantized, and the quantized features are upsampled and passed through the decoder to combine the multi-scale information. In addition, the paper also proposes reducing the codebook dimension from 256 to 32 while keeping the same reconstruction quality, and l2-normalizing the encoded features and the codebook entries so that the nearest code is chosen by cosine similarity.
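A rough sketch of the two-level hierarchy described above, reusing the VectorQuantizer sketch from earlier; the layer sizes and upsampling choices are illustrative assumptions, not the paper's exact architecture:

```python
import torch
import torch.nn as nn

class VQVAE2(nn.Module):
    """Two-scale VQ-VAE: a bottom encoder, a top encoder on top of it,
    one quantizer per scale, and a decoder that sees both code maps."""
    def __init__(self, dim=128, codebook_size=512):
        super().__init__()
        self.enc_bottom = nn.Sequential(            # x -> 1/4 resolution
            nn.Conv2d(3, dim, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(dim, dim, 4, stride=2, padding=1))
        self.enc_top = nn.Sequential(               # 1/4 -> 1/8 resolution
            nn.Conv2d(dim, dim, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(dim, dim, 3, padding=1))
        self.quant_top = VectorQuantizer(codebook_size, dim)
        self.quant_bottom = VectorQuantizer(codebook_size, dim)
        self.up_top = nn.ConvTranspose2d(dim, dim, 4, stride=2, padding=1)
        self.decoder = nn.Sequential(               # back to full resolution
            nn.ConvTranspose2d(2 * dim, dim, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(dim, 3, 4, stride=2, padding=1))

    def forward(self, x):
        h_b = self.enc_bottom(x)
        h_t = self.enc_top(h_b)
        # quantize each scale separately (the real model also feeds the
        # top codes back into the bottom path before quantizing it)
        q_t, loss_t, _ = self.quant_top(h_t.permute(0, 2, 3, 1))
        q_t = q_t.permute(0, 3, 1, 2)
        q_b, loss_b, _ = self.quant_bottom(h_b.permute(0, 2, 3, 1))
        q_b = q_b.permute(0, 3, 1, 2)
        # upsample the top codes and decode both scales together
        recon = self.decoder(torch.cat([self.up_top(q_t), q_b], dim=1))
        return recon, loss_t + loss_b
```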
Residual VQ
The logic is very simple: quantize the input, then keep quantizing whatever residual is left over, i.e. quantize(x), quantize(x - quantize(x)), and so on; the final output is the sum of all quantized residuals.
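A minimal sketch of that loop, assuming a list of quantizers (or one shared quantizer) with the same interface as the VectorQuantizer above:

```python
import torch

def residual_vq(x, quantizers):
    """Quantize x with a stack of quantizers, each one handling the
    residual left over by the previous ones."""
    residual = x
    out = torch.zeros_like(x)
    total_loss = 0.0
    all_indices = []
    for quantizer in quantizers:
        q, loss, indices = quantizer(residual)
        residual = residual - q          # what is still unexplained
        out = out + q                    # running sum of quantized parts
        total_loss = total_loss + loss
        all_indices.append(indices)
    return out, total_loss, all_indices
```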
SIMVQ
According to the authors of the paper, applying a linear (dimension) transformation on top of the codebook improves codebook utilization, resulting in better performance with many optimizers.
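A minimal sketch of the idea, assuming the quantization step works as before and the effective codebook is reparameterized through a learnable linear layer on top of a fixed base codebook (this is my reading of the trick, not the paper's exact code):

```python
import torch
import torch.nn as nn

class SimVQQuantizer(nn.Module):
    def __init__(self, codebook_size, codebook_dim):
        super().__init__()
        # base codebook, kept fixed ...
        self.codebook = nn.Parameter(torch.randn(codebook_size, codebook_dim),
                                     requires_grad=False)
        # ... and reparameterized through a learnable linear transformation
        self.proj = nn.Linear(codebook_dim, codebook_dim, bias=False)

    def forward(self, z_flat):                       # [size, codebook_dim]
        codes = self.proj(self.codebook)             # effective codebook
        d = torch.cdist(z_flat, codes)
        indices = d.argmin(dim=1)
        z_q = codes[indices]
        # straight-through estimator as before
        z_q = z_flat + (z_q - z_flat).detach()
        return z_q, indices
```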
You can see that the code above often uses einops, einx, and torch's einsum; these are very convenient libraries and functions, so here is a short introduction.
Common operations in einops
rearrange
The most commonly used function is rearrange, which can reorder axes, compose them, decompose them, and so on.
```python
x = torch.randn(10,20,10,10)
y = rearrange(x, 'b c h w -> b h w c')   # reorder axes
```
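A few more patterns for composition and decomposition on the same x; the shapes in the comments are what each call returns:

```python
from einops import rearrange
import torch

x = torch.randn(10, 20, 10, 10)                         # b c h w
# composition: merge several axes into one
flat = rearrange(x, 'b c h w -> b (c h w)')             # [10, 2000]
# decomposition: split one axis into several
grid = rearrange(x, 'b (g c) h w -> b g c h w', g=2)    # [10, 2, 10, 10, 10]
# compose and decompose at once, e.g. tile a batch into one grid
tiled = rearrange(x, '(b1 b2) c h w -> c (b1 h) (b2 w)', b1=2)  # [20, 20, 50]
```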
reduce
```python
# yet another example. Can you compute result shape?
```
It can be used for finding the mean, maxpooling, and so on.
```python
ims = torch.randn((10,20,30,30))*10-2
```
Using () in the output pattern keeps the reduced dimension (with size 1), or you can also write 1 explicitly.
```python
data = torch.randn(10,20,30,40)
```
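Some reduce examples covering the cases above (mean, max-pooling, and keeping dimensions); the shapes in the comments are what the calls return:

```python
from einops import reduce
import torch

ims = torch.randn((10, 20, 30, 30)) * 10 - 2        # b c h w
data = torch.randn(10, 20, 30, 40)

mean_per_channel = reduce(ims, 'b c h w -> b c', 'mean')            # [10, 20]
max_pooled = reduce(ims, 'b c (h h2) (w w2) -> b c h w', 'max',
                    h2=2, w2=2)                                      # [10, 20, 15, 15]
# keep the reduced axes with size 1, either with () or an explicit 1
kept = reduce(data, 'b c h w -> b c () ()', 'max')                   # [10, 20, 1, 1]
kept = reduce(data, 'b c h w -> b c 1 1', 'max')                     # [10, 20, 1, 1]
```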
stack and concatenation
```python
# rearrange can also take care of lists of arrays with the same shape
```
When a list of tensors is passed, the list index becomes a new leading dimension, which can then be rearranged like any other axis.
```python
c = list()
```
Or reduce over the list dimension to get the sum, max, etc. of all the tensors in the list, as sketched below.
```python
c = list()
```
add or remove axis
```python
x = rearrange(x, 'b h w c -> b 1 h w 1 c')   # add two axes of size 1
x = rearrange(x, 'b 1 h w 1 c -> b h w c')   # remove them again
```
channel shuffle
```python
c = torch.randn(10,30,10,10)
c = rearrange(c, 'b (g1 g2) h w -> b (g2 g1) h w', g1=3)  # channel shuffle
```
repeat
```python
repeat(x, 'b h w c -> b (h 2) (w 2) c')   # 2x nearest-neighbour upsampling by repetition
```
split dimension
```python
c = torch.randn(10,30,10,10)
```
Splitting can be written in different ways, and the pattern you choose determines which slices you get:
```python
y1, y2 = rearrange(x, 'b (split c) h w -> split b c h w', split=2)
# with '(split c)' this is equivalent to taking contiguous halves:
y1 = x[:, :x.shape[1] // 2, :, :]
# whereas '(c split)' would correspond to interleaved channels:
y1 = x[:, 0::2, :, :]
```
striding anything
```python
# each image is split into subgrids, each subgrid now is a separate "image"
y = rearrange(x, 'b c (h hs) (w ws) -> (hs ws b) c h w', hs=2, ws=2)
```
You can see that the most commonly used functions are rearrange, reduce and repeat, which essentially replace the original sum, transpose, expand, reshape and other torch operations.
parse_shape
With parse_shape, you can more conveniently obtain the sizes of the dimensions you care about.
```python
y = np.zeros([700])
```
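A small illustration of parse_shape with my own example sizes (an underscore marks an axis we don't care about):

```python
import numpy as np
from einops import parse_shape, rearrange

x = np.random.randn(10, 20, 30, 40)
parse_shape(x, 'b c h w')     # {'b': 10, 'c': 20, 'h': 30, 'w': 40}
parse_shape(x, 'b _ h w')     # {'b': 10, 'h': 30, 'w': 40}, '_' skips an axis

# handy for reshaping a flat buffer to match another tensor's layout
y = np.zeros([10 * 30 * 40])
y = rearrange(y, '(b h w) -> b h w', **parse_shape(x, 'b _ h w'))   # [10, 30, 40]
```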
pack and unpack
pack is a way to put several arrays together along one packed dimension while keeping the other, aligned dimensions intact.
```python
h,w = 100,200
```
```python
unpacked_rgb,unpacked_depth = unpack(img_rgbd,ps,"h w *")
```
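A complete round trip, assuming an RGB image and a depth map of the same spatial size (this mirrors the standard einops pack/unpack example):

```python
import numpy as np
from einops import pack, unpack

h, w = 100, 200
image_rgb = np.random.random([h, w, 3])
image_depth = np.random.random([h, w])

# pack along the trailing '*' axis: depth contributes a single channel
img_rgbd, ps = pack([image_rgb, image_depth], 'h w *')    # [100, 200, 4]

# unpack restores the original shapes using the recorded packed shapes `ps`
unpacked_rgb, unpacked_depth = unpack(img_rgbd, ps, 'h w *')
print(unpacked_rgb.shape, unpacked_depth.shape)           # (100, 200, 3) (100, 200)
```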
Using layers in conjunction with torch
```python
from einops.layers.torch import Rearrange,Reduce
```
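These are nn.Module versions of the same operations, so they can sit directly inside nn.Sequential; a small illustrative head (the layer sizes assume a 32x32 input and are my own choice):

```python
import torch.nn as nn
from einops.layers.torch import Rearrange, Reduce

model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2),  # 2x2 max-pool
    Rearrange('b c h w -> b (c h w)'),                          # flatten
    nn.Linear(64 * 16 * 16, 10),                                # assumes 32x32 input
)
```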
EinMix
Similar to torch.einsum, einsum (see the einsum tutorial) is a convenient way to compute products of multiple tensors, while EinMix makes it easy to write MLP-style architectural code by combining a pattern with weight_shape and bias_shape.
```python
from einops.layers.torch import EinMix as Mix
```
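For example, a per-token channel-mixing layer can be written in one line; the sizes here are arbitrary:

```python
import torch
from einops.layers.torch import EinMix as Mix

# mix information across the channel axis c, independently per token t
mix_channels = Mix('b t c -> b t c_out', weight_shape='c c_out', bias_shape='c_out',
                   c=256, c_out=512)

x = torch.randn(8, 196, 256)      # batch, tokens, channels
y = mix_channels(x)               # [8, 196, 512]
```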
For what it’s worth, einops also has einsum
```python
from einops import einsum, pack, unpack
```
Relevant information
- MishaLaskin/vqvae: A pytorch implementation of the vector quantized variational autoencoder (https://arxiv.org/abs/1711.00937)
- VQ-VAE/vq_vae/auto_encoder.py at master · nadavbh12/VQ-VAE
- VQ-VAE/vqvae.py at main · AndrewBoessen/VQ-VAE
- vqvae-2/vqvae.py at main · vvvm23/vqvae-2
- Autoregressive Models in Deep Learning — A Brief Survey | George Ho
- lucidrains/vector-quantize-pytorch: Vector (and Scalar) Quantization, in Pytorch
- A brief introduction to VQ-VAE: quantized autoencoder – Scientific Spaces
- The rotation trick for VQ: a generalization of straight-through gradient estimation – Scientific Spaces
- Another trick for VQ: add a linear transformation to the coding table – Scientific Spaces
- Writing better code with pytorch+einops
- Residual Vector Quantisation – Notes by Lex
- rese1f/Awesome-VQVAE: A collection of resources and papers on Vector Quantized Variational Autoencoder (VQ-VAE) and its application