Learning the Code of VQ-VAE and Its Variants

VQ-VAE, from [1711.00937] Neural Discrete Representation Learning, performs unsupervised learning of discrete representations and is still in use in the field of multimodal generation. Below are notes on the code.

VQVAE

The idea behind VQ-VAE is simple. It was proposed together with PixelCNN, so it is closely tied to autoregressive models. Generative models such as VAEs and GANs estimate the data as a whole, whereas an autoregressive model is a sequence model and is closer to modeling the data-generating distribution directly:

Autoregressive models condition their predictions on prior values in the series, rather than on underlying random variables. Thus, they attempt to explicitly model the data generating distribution rather than approximate it

PixelCNN is an autoregressive model: it generates the discrete codes produced by VQ-VAE sequentially, one position at a time. To achieve this it uses a masked convolution, in which the kernel weights that cover later positions are set to 0, so the convolution ignores results that have not been generated yet. The code comes from ToyPixelCNN.ipynb at master – pilipolio/learn-pytorch.

import torch
import torch.nn as nn


class MaskedConv(nn.Conv2d):
    def __init__(self, mask_type, *args, **kwargs):
        super(MaskedConv, self).__init__(*args, **kwargs)
        self.mask_type = mask_type
        # The mask is a buffer: it is saved with the module but not trained
        self.register_buffer('mask', self.weight.data.clone())

        channels, depth, height, width = self.weight.size()

        self.mask.fill_(1)
        if mask_type == 'A':
            # Type A: hide the centre pixel and everything after it
            self.mask[:, :, height // 2, width // 2:] = 0
            self.mask[:, :, height // 2 + 1:, :] = 0
        else:
            # Type B: keep the centre pixel, hide everything after it
            self.mask[:, :, height // 2, width // 2 + 1:] = 0
            self.mask[:, :, height // 2 + 1:, :] = 0

    def forward(self, x):
        # Zero the masked weights before every convolution
        self.weight.data *= self.mask
        return super(MaskedConv, self).forward(x)
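A quick way to convince yourself that such a mask is causal is to change one input pixel and check that no earlier output moves. The sketch below is self-contained and re-declares a compact masked convolution for that purpose; the class name CausalConv2d and the 5x5 test input are illustrative assumptions, not part of the notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalConv2d(nn.Conv2d):
    """Illustrative masked convolution (hypothetical name, same idea as MaskedConv)."""

    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        _, _, kh, kw = self.weight.shape
        mask = torch.ones_like(self.weight)
        # Type 'A' also hides the centre pixel; type 'B' keeps it.
        start = kw // 2 if mask_type == "A" else kw // 2 + 1
        mask[:, :, kh // 2, start:] = 0    # centre row, from `start` to the right
        mask[:, :, kh // 2 + 1:, :] = 0    # every row below the centre
        self.register_buffer("mask", mask)

    def forward(self, x):
        # Apply the mask on the fly instead of editing the weights in place.
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


torch.manual_seed(0)
conv = CausalConv2d("A", 1, 1, kernel_size=3, padding=1)
x = torch.randn(1, 1, 5, 5)
x_perturbed = x.clone()
x_perturbed[0, 0, 2, 2] += 1.0  # change the pixel at row 2, column 2

diff = (conv(x) - conv(x_perturbed)).abs()[0, 0]
# Outputs at or before (2, 2) in raster order must not change.
earlier_unchanged = bool((diff[:2] < 1e-6).all() and (diff[2, :3] < 1e-6).all())
print(earlier_unchanged)
```

Only outputs that come after the perturbed pixel in raster order, such as (2, 3) and the row below, react to the change.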

Many models, including Transformers, are now autoregressive, whereas GANs and VAEs are not; a drawback of the latter is that they have difficulty modeling discrete data. VQ-VAE compensates for this.

The core of VQ-VAE is a discrete dictionary (the codebook), plus a technique for passing gradients through the non-differentiable lookup so that the encoder and the dictionary can both be updated.

This technique is called the straight-through estimator: the gradient that the decoder produces for the quantized vector is passed directly to the encoder output. Suppose the codebook has shape [codebook_size, codebook_dim] and the input features have shape [size, codebook_dim]. Computing their pairwise distances with some metric (torch.cdist can be used) gives a [size, codebook_size] matrix, and taking the argmin over its last dimension gives, for each feature position, the index of its nearest dictionary entry.

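# Approach 1: expand ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y by hand
# x: [size, codebook_dim] features, y: [codebook_size, codebook_dim] codebook
dist_manual = torch.sqrt(
    torch.clamp(
        torch.sum(x ** 2, dim=1, keepdim=True)
        + torch.sum(y ** 2, dim=1, keepdim=True).t()
        - 2 * x @ y.t(),
        min=0.0,  # rounding can make the expansion slightly negative
    )
)
# Approach 2: more readable, and efficient since no gradient is computed
# (implicit_codebook plays the role of y)
with torch.no_grad():
    dist = torch.cdist(x, implicit_codebook)  # [size, codebook_size]
    indices = dist.argmin(dim=-1)  # [size]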

The nearest-code indices are then turned into one-hot encodings, which are used to look up the embedded features.

# Approach 1
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)  # (encoded_feat size, 1)
min_encodings = torch.zeros(
    min_encoding_indices.shape[0], self.n_e, device=z.device)  # (encoded_feat size, embedding_size)
min_encodings.scatter_(1, min_encoding_indices, 1)  # one-hot like
# Approach 2: DRY and cleaner
min_encoding_indices = torch.argmin(d, dim=1)
# num_classes keeps the width at n_e even when the last codes are unused;
# .float() is needed because F.one_hot returns int64
my_min_encodings = F.one_hot(min_encoding_indices, num_classes=self.n_e).float()

The one-hot encoding has shape [encode_size, embed_size]. In the following equation from the paper, the third term is the commitment loss, which is used to update the encoder output, and the second term is used to update the dictionary (sg is the stop-gradient operator, detach in the code):

$$L = \log p(x \mid z_q(x)) + \lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2 + \beta \lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2$$

To learn the embedding space, one of the simplest dictionary learning algorithms, vector quantization (VQ), is used. The VQ objective uses the l2 error to move the embedding vectors e_i towards the encoder outputs z_e(x).

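z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# Dictionary term + beta * commitment term; beta weights the term that
# updates the encoder, as in the paper
loss = torch.mean((z_q - z.detach()) ** 2) + self.beta * torch.mean((z_q.detach() - z) ** 2)
# Straight-through estimator: the forward value is z_q, the gradient flows to z
z_q = z + (z_q - z).detach()
# torch.mean((z_q - z.detach()) ** 2) can be written more simply as
F.mse_loss(z_q, z.detach())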
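To see the pieces working together, here is a minimal self-contained sketch of a quantizer layer. It assumes a flattened encoder output of shape [size, e_dim]; the class name TinyVectorQuantizer and the sizes are made up for illustration, and the lookup uses nn.Embedding indexing, which gives the same result as multiplying a one-hot matrix by the weight.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyVectorQuantizer(nn.Module):
    """Minimal VQ layer (hypothetical name): n_e codes of dimension e_dim."""

    def __init__(self, n_e=8, e_dim=4, beta=0.25):
        super().__init__()
        self.beta = beta
        self.embedding = nn.Embedding(n_e, e_dim)

    def forward(self, z):
        # z: [size, e_dim] flattened encoder output
        with torch.no_grad():
            indices = torch.cdist(z, self.embedding.weight).argmin(dim=-1)
        z_q = self.embedding(indices)                   # nearest code per row
        codebook_loss = F.mse_loss(z_q, z.detach())     # moves the codes
        commitment_loss = F.mse_loss(z, z_q.detach())   # moves the encoder
        loss = codebook_loss + self.beta * commitment_loss
        # Straight-through: forward value is z_q, backward gradient goes to z.
        z_q = z + (z_q - z).detach()
        return z_q, loss, indices


torch.manual_seed(0)
vq = TinyVectorQuantizer()
z = torch.randn(16, 4, requires_grad=True)
z_q, loss, indices = vq(z)
(z_q.sum() + loss).backward()
print(z_q.shape, indices.shape)
print(z.grad is not None, vq.embedding.weight.grad is not None)
```

Both the encoder output and the codebook receive gradients even though argmin itself is not differentiable.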

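The dictionary can also be updated with an exponential moving average (EMA) instead of the dictionary loss term.

The update logic is as follows. At each step, the number of encoder outputs whose nearest code is a given embedding vector is counted, and ema_cluster_size tracks this count with an EMA. The sum of the encoder outputs assigned to each embedding is tracked with an EMA as well (ema_w), and the embedding weight is ema_w divided by ema_cluster_size.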

# Update weights with EMA
if self.training:
    # EMA of how many encoder outputs were assigned to each code
    self._ema_cluster_size = self._ema_cluster_size * self._decay + (
        1 - self._decay
    ) * torch.sum(encodings, 0)

    # Laplace smoothing, so that unused codes never get a zero count
    n = torch.sum(self._ema_cluster_size.data)
    self._ema_cluster_size = (
        (self._ema_cluster_size + self._epsilon)
        / (n + self._n_embeddings * self._epsilon)
        * n
    )

    # Sum of the encoder outputs assigned to each code
    dw = torch.matmul(encodings.t(), flat_z_e)
    self._ema_w = nn.Parameter(
        self._ema_w * self._decay + (1 - self._decay) * dw
    )

    # New code = EMA of assigned sums / EMA of assigned counts
    self._embedding.weight = nn.Parameter(
        self._ema_w / self._ema_cluster_size.unsqueeze(1)
    )
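Written as equations, the EMA update in the code above reads roughly as follows. The symbols are my own shorthand rather than the author's notation: γ is _decay, ε is _epsilon, K is _n_embeddings, n_i is the number of encoder outputs assigned to code i in the current batch, N_i is _ema_cluster_size before smoothing, N̂_i is the same quantity after Laplace smoothing, m_i is _ema_w, and k_j is the code index chosen for the j-th encoder output.

```latex
\begin{aligned}
N_i^{(t)} &= \gamma\, \hat{N}_i^{(t-1)} + (1-\gamma)\, n_i^{(t)} \\
\hat{N}_i^{(t)} &= \frac{N_i^{(t)} + \epsilon}{\sum_{l} N_l^{(t)} + K\epsilon}\, \sum_{l} N_l^{(t)} \\
m_i^{(t)} &= \gamma\, m_i^{(t-1)} + (1-\gamma) \sum_{j:\, k_j = i} z_{e,j}^{(t)} \\
e_i^{(t)} &= \frac{m_i^{(t)}}{\hat{N}_i^{(t)}}
\end{aligned}
```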

VQVAE-2

It is simply a multi-scale VQ-VAE, built from several encoder, codebook layer, and decoder stages.
First, the input is downsampled by successive encoders to obtain features at different scales. Each scale is then quantized, and the quantized features are upsampled and decoded. In addition, the Improved VQGAN (ViT-VQGAN) paper proposes reducing the codebook dimension from 256 to 32 while keeping the same reconstruction quality, and l2-normalizing both the encoded features and the codebook, so that the nearest code is chosen by cosine similarity.
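The l2-normalization point can be checked numerically. The sketch below uses made-up sizes (16 features, 64 codes, dimension 32) and random tensors in place of real encoder outputs; it only illustrates that, once both sides are normalized, picking the code with the highest cosine similarity is the same as picking the code with the smallest Euclidean distance.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(16, 32)         # stand-in for encoder features in the low dimension
codebook = torch.randn(64, 32)  # stand-in for a 64-entry codebook

z_n = F.normalize(z, dim=-1)
codebook_n = F.normalize(codebook, dim=-1)

# Cosine similarity: larger means closer.
indices_cos = (z_n @ codebook_n.t()).argmax(dim=-1)

# On unit vectors ||a - b||^2 = 2 - 2 a.b, so the Euclidean argmin picks the same codes.
sq_dist = ((z_n[:, None, :] - codebook_n[None, :, :]) ** 2).sum(dim=-1)
indices_l2 = sq_dist.argmin(dim=-1)

print(torch.equal(indices_cos, indices_l2))
```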

Residual VQ

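The logic is very simple: quantize x, then quantize whatever residual is left, and repeat, i.e. q1 = quantize(x), q2 = quantize(x - q1), q3 = quantize(x - q1 - q2), … The reconstruction is the sum of the codes chosen at every stage.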
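A minimal sketch of this loop, assuming one plain nearest-neighbour codebook per stage. The function names and the random codebooks are illustrative only; a trained model would learn the codebooks.

```python
import torch


def nearest_code(x, codebook):
    """Return the closest codebook row for every row of x, plus its index."""
    indices = torch.cdist(x, codebook).argmin(dim=-1)
    return codebook[indices], indices


def residual_quantize(x, codebooks):
    """Quantize x stage by stage; each stage codes the previous residual."""
    residual = x
    quantized = torch.zeros_like(x)
    all_indices = []
    for codebook in codebooks:
        q, idx = nearest_code(residual, codebook)
        quantized = quantized + q   # running sum of the selected codes
        residual = residual - q     # what is still unexplained
        all_indices.append(idx)
    return quantized, torch.stack(all_indices, dim=-1), residual


torch.manual_seed(0)
x = torch.randn(32, 8)
# Four random codebooks with shrinking scale (an arbitrary choice for the demo).
codebooks = [torch.randn(16, 8) * 0.5 ** i for i in range(4)]
quantized, indices, residual = residual_quantize(x, codebooks)
print(quantized.shape, indices.shape)
```

Each input row ends up described by one index per stage, and the input equals the quantized sum plus the final residual.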

SIMVQ

According to the authors of the paper, applying a linear transformation to the codebook improves codebook utilization and gives better performance across many optimizers.
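One way to read "a linear transformation on the codebook" is that every code is produced as a shared matrix applied to a base code. The sketch below follows that reading; the class name LinearCodebook, the sizes, and the choice to keep the base codes fixed are assumptions for illustration and should not be taken as the paper's exact recipe.

```python
import torch
import torch.nn as nn


class LinearCodebook(nn.Module):
    """Codebook reparameterized as proj(base) (hypothetical name)."""

    def __init__(self, n_e=64, e_dim=16):
        super().__init__()
        # Base codes are a fixed buffer here (assumption); only proj is trained.
        self.register_buffer("base", torch.randn(n_e, e_dim))
        self.proj = nn.Linear(e_dim, e_dim, bias=False)

    def forward(self, z):
        codebook = self.proj(self.base)  # all codes share the same matrix
        with torch.no_grad():
            indices = torch.cdist(z, codebook).argmin(dim=-1)
        return codebook[indices], indices


torch.manual_seed(0)
layer = LinearCodebook()
z = torch.randn(10, 16)
z_q, indices = layer(z)
loss = ((z_q - z) ** 2).mean()
loss.backward()
print(z_q.shape, layer.proj.weight.grad.abs().sum().item() > 0)
```

Because the matrix is shared, a gradient step driven by the few selected codes moves every code in the codebook, including the ones that were not selected.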

The implementations referenced above make frequent use of einops, einx, and torch.einsum, which are very convenient libraries and functions. Here is an introduction.

Common operations in einops

rearrange

The most commonly used function is rearrange, which can reorder axes, compose (merge) them, decompose (split) them, and so on.

import torch
from einops import rearrange

x = torch.randn(10, 20, 10, 10)
# order
y = rearrange(x, 'b c h w -> b h w c')
print(y.shape)  # torch.Size([10, 10, 10, 20])
# composition
y = rearrange(x, 'b c h w -> b c (h w)')
# decomposition: one of the merged lengths must be given to split (h w)
y = rearrange(y, 'b c (h w) -> b h w c', h=10)
y = rearrange(y, '(b1 b2) h w c -> b1 b2 h w c', b1=2)

reduce

# yet another example. Can you compute result shape?
reduce(ims, "(b1 b2) h w c -> (b2 h) (b1 w)", "mean", b1=2)

reduce can be used to compute means, max pooling, and so on.

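ims = torch.randn((10, 20, 30, 30)) * 10 - 2
b, c, h, w = ims.shape
m_ims = reduce(ims, 'b c h w -> b c', "min")  # minimum over h and w
print(m_ims.shape)  # torch.Size([10, 20])

# No axis is dropped here, so nothing is reduced and the round trip returns ims
m_ims = reduce(ims, 'b c h w -> b (h w) c', 'min').transpose(1, 2).reshape(b, c, h, w)
print(m_ims.shape)
print((ims == m_ims).all())
# 2x2 average pooling
mean2_ims = reduce(ims, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=2, w2=2)
# The same tensor read as b h w c: 2x2 max pooling, then batch merged into width
reduce(ims, 'b (h h2) (w w2) c -> h (b w) c', "max", h2=2, w2=2)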

Use () to keep a reduced dimension (with length 1); you can also write 1 instead.

data = torch.randn(10, 20, 30, 40)
mean_ = reduce(data, 'b c h w -> b c () ()', 'mean')  # mean with kept dims
ans = data.mean(dim=[2, 3], keepdim=True)
print(((ans - mean_).abs() < 1e-6).float().mean())

# 2x2 max pooling: the window axis must be the inner one, i.e. (h 2), not (2 h)
max_pool = reduce(data, 'b c (h 2) (w 2) -> b c h w', 'max')
adaptive_max_pool = reduce(data, 'b c h w -> b c ()', 'max')

stack and concatenation

# rearrange can also take care of lists of arrays with the same shape
x = list(ims)
print(type(x), "with", len(x), "tensors of shape", x[0].shape)
# that's how we can stack inputs
# "list axis" becomes first ("b" in this case), and we left it there
rearrange(x, "b h w c -> b h w c").shape

Moving the list dimension to another position among the tensor dimensions

c = list()
c.append(torch.randn(10,20,30))
c.append(torch.randn(10,20,30))

rearrange(c,'l c h w -> c l h w').shape

Or compute the sum, max, etc. across all tensors in a list

c = list()
c.append(torch.randn(10, 20, 30))
c.append(torch.randn(10, 20, 30))

rearrange(c, 'l c h w -> c l h w').shape  # torch.Size([10, 2, 20, 30])
# l is the list axis; dropping it reduces across the tensors in the list
reduce(c, 'l c h w -> c h w', 'mean').shape
reduce(c, 'l c h w -> c h w', 'sum').shape
reduce(c, 'l c h w -> c h w', 'max').shape

add or remove axis

x = rearrange(x, 'b h w c -> b 1 h w 1 c')  # add two axes of length 1
y = rearrange(y, 'b h w c -> b h (w c)')  # remove an axis by merging w and c

channel shuffle

c = torch.randn(10,30,10,10)
rearrange(c,'b (g1 g2 c) h w -> b (g2 g1 c) h w',g1=3,g2=5).shape

repeat

repeat(x,'b h w c -> b (h 2) (w 2) c')
repeat(x,'h w c -> h new_axis w c',new_axis=5)
repeat(x,'h w c -> h 5 w c')

split dimension

c = torch.randn(10,30,10,10)
x,y,z = rearrange(c,'b (head c) h w -> head b c h w',head=3)
print(x.shape,y.shape,z.shape)

There are two ways to split a dimension

# Contiguous halves: y1 = x[:, :x.shape[1] // 2, :, :]
y1, y2 = rearrange(x, 'b (split c) h w -> split b c h w', split=2)
result = y2 * sigmoid(y2) # or tanh
# Interleaved channels: y1 = x[:, 0::2, :, :]
y1, y2 = rearrange(x, 'b (c split) h w -> split b c h w', split=2)

striding anything

# each image is split into subgrids, each subgrid now is a separate "image"
y = rearrange(x, "b c (h hs) (w ws) -> (hs ws b) c h w", hs=2, ws=2)
y = convolve_2d(y)
# pack subgrids back to an image
y = rearrange(y, "(hs ws b) c h w -> b c (h hs) (w ws)", hs=2, ws=2)

assert y.shape == x.shape

You can see that the most commonly used functions are rearrange, reduce, and repeat, which essentially replace sum, transpose, expand, reshape, and other torch operations.

parse_shape

With parse_shape, you can more conveniently obtain the dimension sizes you want.

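x = np.zeros([2, 3, 5, 7])  # example tensor whose b, h, w sizes are reused
y = np.zeros([700])
# parse_shape returns {'b': 2, 'h': 5, 'w': 7}; "_" skips an axis
rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape  # (2, 10, 5, 7)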

pack and unpack

pack concatenates several arrays along the axis marked with *, even when some of them lack that axis; unpack reverses the operation using the recorded shapes (ps).

import numpy as np
from einops import pack, unpack

h, w = 100, 200
img_rgb = np.random.random([h, w, 3])
img_depth = np.random.random([h, w])
img_rgbd, ps = pack([img_rgb, img_depth], 'h w *')
print(img_rgbd.shape, ps)  # (100, 200, 4) [(3,), ()]

unpacked_rgb, unpacked_depth = unpack(img_rgbd, ps, "h w *")
print(unpacked_rgb.shape, unpacked_depth.shape)  # (100, 200, 3) (100, 200)

Using layers in conjunction with torch

from einops.layers.torch import Rearrange,Reduce

EinMix

Like torch.einsum (see the einsum tutorial), which is a convenient way to compute products of multiple tensors, EinMix from einops makes it easy to write MLP-style architecture code: a layer is built from a pattern combined with weight_shape and bias_shape.

from einops.layers.torch import EinMix as Mix
mlp = Mix('t b c-> t b c_out',weight_shape='c c_out',c=10,c_out=20)
x = torch.randn(10,30,10)
y = mlp(x)

For what it’s worth, einops also has einsum

from einops import einsum, pack, unpack
# einsum is like ... einsum, generic and flexible dot-product
# but 1) axes can be multi-lettered 2) pattern goes last 3) works with multiple frameworks
C = einsum(A, B, 'b t1 head c, b t2 head c -> b head t1 t2')

Relevant information

  1. MishaLaskin/vqvae: A pytorch implementation of the vector quantized variational autoencoder (https://arxiv.org/abs/1711.00937)
  2. VQ-VAE/vq_vae/auto_encoder.py at master · nadavbh12/VQ-VAE
  3. VQ-VAE/vqvae.py at main · AndrewBoessen/VQ-VAE
  4. vqvae-2/vqvae.py at main · vvvm23/vqvae-2
  5. Autoregressive Models in Deep Learning — A Brief Survey | George Ho
  6. lucidrains/vector-quantize-pytorch: Vector (and Scalar) Quantization, in Pytorch
  7. A brief introduction to VQ-VAE: quantized autoencoder – Scientific Spaces
  8. The rotation trick for VQ: a general extension of straight-through gradient estimation – Scientific Spaces
  9. Another trick for VQ: add a linear transformation to the codebook – Scientific Spaces
  10. Writing better code with pytorch+einops
  11. Residual Vector Quantisation – Notes by Lex
  12. rese1f/Awesome-VQVAE: A collection of resources and papers on Vector Quantized Variational Autoencoder (VQ-VAE) and its application