Variational Autoencoders

Generative models in deep learning have become popular since 2014, when GANs (Generative Adversarial Nets) were introduced by Ian Goodfellow. Another widely used model for generating data is the VAE (Variational Autoencoder), along with its variant, the CVAE (Conditional Variational Autoencoder).

In this post, we give a brief introduction to the VAE and how it works.

The intuition

The latent variable is an indispensable part of an autoencoder. It stores the information needed to reconstruct the original data, and this is what we can use to generate new data. VAEs can be used to generate images, text and even music!

Let’s first define some notation before we talk about the generative process:

  • $X$: the data from the dataset that we want to learn from.
  • $z$: the latent variable, which stores information about $X$.
  • $P(X)$: the probability distribution of the data.
  • $P(z)$: the probability distribution of the latent variable.
  • $P(X|z)$: the probability distribution of the generated data given the latent variable $z$.

The generative process is $P(X) = \int P(X|z)P(z)\,dz$. The intuition behind this framework is that if the model can reconstruct the training samples, it should also be able to produce similar samples. Since any distribution in $d$ dimensions can be generated by taking $d$ normally distributed variables and mapping them through a sufficiently complicated function (which the decoder network can learn), we take $P(z)=N(0,I)$, and all we need to do is maximize $P(X)$ over the training data.
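To make the integral concrete, here is a minimal sketch on a hypothetical one-dimensional toy model (not from the original post) where $P(z)=N(0,1)$ and $P(X|z)=N(z, s^2)$, so the exact marginal is $N(0, 1+s^2)$. Averaging $P(X|z)$ over samples $z \sim P(z)$ gives a Monte Carlo estimate of $P(X)$; the function names and the value of $s$ are illustrative assumptions.

```python
import numpy as np

def normal_pdf(x, mean, var):
    """Density of N(mean, var) evaluated at x."""
    return np.exp(-(x - mean) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)

def marginal_likelihood_mc(x, s=0.5, n_samples=200_000, seed=0):
    """Estimate P(X=x) = integral of P(x|z) P(z) dz by sampling z ~ P(z) = N(0, 1).

    Toy decoder (an assumption for illustration): P(x|z) = N(z, s^2).
    """
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(n_samples)          # z ~ N(0, 1)
    return normal_pdf(x, z, s ** 2).mean()      # average of P(x|z) over the samples

def marginal_likelihood_exact(x, s=0.5):
    """For this linear-Gaussian toy model the integral has a closed form: N(0, 1 + s^2)."""
    return normal_pdf(x, 0.0, 1.0 + s ** 2)

x = 1.2
print(marginal_likelihood_mc(x), marginal_likelihood_exact(x))  # the two values are close
```

In a real VAE the decoder is a neural network and $X$ is high-dimensional, so almost every $z$ drawn from $P(z)$ gives $P(X|z) \approx 0$; this naive estimator becomes useless, which is why the VAE samples $z$ from an encoder $Q(z|X)$ instead.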

The model

As its name suggests, an autoencoder contains two parts: $Q(z|X)$, which encodes a data sample $X$ into the latent $z$, and $P(X|z)$, which decodes the latent $z$ and tries to reconstruct the original data. We want the encoder to produce values of $z$ that are likely to have produced $X$. But how can we make sure that $Q(z|X)$ is close to the true posterior $P(z|X)$? This is where the KL divergence comes in.

We won’t work through the full derivation here; in short, the KL divergence measures how different two distributions are. We can therefore use it as part of our loss function and optimize it with gradient descent. If you want to explore how it works in a VAE, here is something you should check out.
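For intuition, here is a minimal sketch of the KL divergence between two discrete distributions, $D[P\Vert Q] = \sum_i P_i \log(P_i/Q_i)$. It is zero when the two distributions are identical, positive otherwise, and not symmetric, so it is a divergence rather than a true distance. The example distributions are made up for illustration.

```python
import numpy as np

def kl_divergence(p, q):
    """D[p || q] = sum_i p_i * log(p_i / q_i) for discrete distributions.

    Terms with p_i == 0 contribute 0 by convention; q must be > 0 wherever p > 0.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

p = [0.7, 0.2, 0.1]
q = [0.4, 0.4, 0.2]
print(kl_divergence(p, p))                        # identical distributions -> 0.0
print(kl_divergence(p, q), kl_divergence(q, p))   # both positive, but not equal
```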

The loss function

The VAE objective function is:

$\log P(X) - D[Q(z|X)\Vert P(z|X)] = E_{z\sim Q(z|X)}[\log P(X|z)] - D[Q(z|X)\Vert P(z)]$

The left-hand side is the quantity we want to maximize: the log-likelihood of the data minus an error term. The error term $D[Q(z|X)\Vert P(z|X)]$ is close to $0$ when the encoder produces the values of $z$ most likely to reproduce the data; since a KL divergence is never negative, the right-hand side is a lower bound on $\log P(X)$. The right-hand side is what we can optimize through gradient descent (in practice, we minimize its negative as the loss). The first term is the reconstruction term, which makes the generated images look like the training data, while the second is the KL divergence between the encoder's distribution over $z$ and the prior $N(0,I)$.
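The identity $\log P(X) - D[Q(z|X)\Vert P(z|X)] = E_{z\sim Q(z|X)}[\log P(X|z)] - D[Q(z|X)\Vert P(z)]$ can be checked numerically on a toy model where every term has a closed form. The sketch below assumes a hypothetical one-dimensional model with $P(z)=N(0,1)$, $P(X|z)=N(z,s^2)$ and an encoder output $Q(z|X)=N(m,v)$ with arbitrarily chosen $m$ and $v$; these are illustrative choices, not part of the original post. Both sides agree for any $m, v$, and they reach $\log P(X)$ when $Q(z|X)$ equals the true posterior.

```python
import numpy as np

def log_normal(x, mean, var):
    """log N(x; mean, var)."""
    return -0.5 * np.log(2 * np.pi * var) - (x - mean) ** 2 / (2 * var)

def kl_gauss(m1, v1, m2, v2):
    """D[N(m1, v1) || N(m2, v2)] for 1-D Gaussians (v = variance)."""
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def both_sides(x, m, v, s=0.5):
    """Evaluate both sides of the VAE identity for a toy model.

    Assumed toy model: P(z) = N(0, 1), P(X|z) = N(z, s^2), encoder Q(z|X) = N(m, v).
    """
    log_px = log_normal(x, 0.0, 1.0 + s ** 2)                  # exact log P(X)
    post_m, post_v = x / (1 + s ** 2), s ** 2 / (1 + s ** 2)   # exact posterior P(z|X)
    lhs = log_px - kl_gauss(m, v, post_m, post_v)
    # E_{z~Q}[log P(X|z)] in closed form, using E[(X - z)^2] = (X - m)^2 + v
    expected_loglik = -0.5 * np.log(2 * np.pi * s ** 2) - ((x - m) ** 2 + v) / (2 * s ** 2)
    rhs = expected_loglik - kl_gauss(m, v, 0.0, 1.0)
    return log_px, lhs, rhs

log_px, lhs, rhs = both_sides(x=1.2, m=0.3, v=0.5)
print(lhs, rhs, log_px)  # lhs equals rhs up to rounding, and both are below log P(X)
```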

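The loss function written in Python with the Keras framework (Keras 2 backend API). Here `z_mean`, `z_log_sigma` (which holds the log of the variance, as the `K.exp(z_log_sigma)` term shows) and `original_dim` (the flattened input size, e.g. 784 for MNIST) are assumed to be defined by the surrounding model code:

```python
from keras import backend as K, losses

def vae_loss(x, x_decoded_mean):
    # binary_crossentropy averages over pixels; multiplying by original_dim sums them
    recon_loss = original_dim * losses.binary_crossentropy(x, x_decoded_mean)
    # Closed-form KL to N(0, I), summed over latent dimensions
    kl_loss = -0.5 * K.sum(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
    return recon_loss + kl_loss
```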
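To sanity-check shapes and weighting outside Keras, here is a minimal NumPy sketch of this kind of per-example loss, with the reconstruction term summed over pixels and the KL term summed over latent dimensions. It assumes a Bernoulli decoder (so reconstruction is binary cross-entropy) and an encoder that outputs a mean and a log-variance per latent dimension; the function name `vae_loss_np` and the toy batch are hypothetical.

```python
import numpy as np

def vae_loss_np(x, x_decoded_mean, z_mean, z_log_var, eps=1e-7):
    """Per-example negative ELBO: summed binary cross-entropy + KL to N(0, I).

    x, x_decoded_mean: arrays of shape (batch, original_dim) with values in [0, 1].
    z_mean, z_log_var: arrays of shape (batch, latent_dim); z_log_var is log(sigma^2).
    """
    p = np.clip(x_decoded_mean, eps, 1 - eps)  # avoid log(0)
    recon = -np.sum(x * np.log(p) + (1 - x) * np.log(1 - p), axis=-1)
    kl = -0.5 * np.sum(1 + z_log_var - np.square(z_mean) - np.exp(z_log_var), axis=-1)
    return recon + kl

rng = np.random.default_rng(0)
x = (rng.random((4, 784)) > 0.5).astype(float)   # fake binarized "images"
x_hat = rng.random((4, 784))                      # fake decoder outputs
z_mean = np.zeros((4, 2))
z_log_var = np.zeros((4, 2))                      # encoder output equals the prior, so KL = 0
print(vae_loss_np(x, x_hat, z_mean, z_log_var).shape)  # (4,)
```

Moving `z_mean` away from zero adds $\tfrac{1}{2}\mu^2$ per latent dimension to the KL term, which is how the loss pulls the encoder towards the prior.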

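However, we now face a problem: how do we get $z$ from the encoder? Sampling $z$ directly from $Q(z|X)$ won’t work, because sampling is not differentiable, so no gradient can flow back to the encoder. We therefore use the reparameterization trick: sample $\epsilon \sim N(0, I)$ and compute $z = \mu + \sigma \odot \epsilon$, which is differentiable with respect to the encoder outputs $\mu$ and $\sigma$.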

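```python
from keras import backend as K

def sample_z(args):
    # mu: encoder mean; log_var: encoder log-variance, both of shape (batch, n_z)
    mu, log_var = args
    eps = K.random_normal(shape=K.shape(mu), mean=0., stddev=1.)  # eps ~ N(0, I)
    return mu + K.exp(log_var / 2) * eps  # z = mu + sigma * eps
```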
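Outside any framework, the trick can be illustrated in NumPy. The sketch below (toy values, hypothetical function name `reparameterize`) draws $z = \mu + \sigma\epsilon$ with $\sigma = \exp(\log\sigma^2 / 2)$ and $\epsilon \sim N(0, I)$, then checks that the samples have the intended mean and standard deviation. Because $\epsilon$ does not depend on the parameters, $\partial z/\partial\mu = 1$ and $\partial z/\partial \log\sigma^2 = \tfrac{1}{2}\sigma\epsilon$, which is exactly the gradient path that direct sampling lacks.

```python
import numpy as np

def reparameterize(mu, log_var, rng):
    """Sample z ~ N(mu, exp(log_var)) as a deterministic function of (mu, log_var, eps)."""
    eps = rng.standard_normal(np.shape(mu))   # all randomness lives in eps ~ N(0, I)
    sigma = np.exp(log_var / 2)
    z = mu + sigma * eps
    # Elementwise gradients of z with respect to the encoder outputs
    dz_dmu = np.ones_like(z)
    dz_dlog_var = 0.5 * sigma * eps
    return z, dz_dmu, dz_dlog_var

rng = np.random.default_rng(0)
mu = np.full(100_000, 2.0)
log_var = np.full(100_000, np.log(0.25))      # sigma = 0.5
z, dz_dmu, dz_dlog_var = reparameterize(mu, log_var, rng)
print(z.mean(), z.std())   # close to 2.0 and 0.5
```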

Conditional VAE

The idea behind the Conditional VAE is the same as the vanilla VAE; it simply adds a label to the input. This way, one can use the label together with a learned latent $z$ to generate data of whichever class one wants.

Take the MNIST dataset as an example: one could attach a one-hot vector to each input as its label:

$[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]$

The vector above is the label for the digit $3$ (with positions indexed from $0$ to $9$). When generating, one can feed this vector to the decoder together with a latent $z$ to produce images of a handwritten $3$.
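Here is a minimal NumPy sketch of one common way the label enters a CVAE: the one-hot vector is concatenated to the flattened image before the encoder and to the latent $z$ before the decoder. The dimensions (784-pixel MNIST images, a 2-D latent space) and the helper names are assumptions for illustration, and the encoder and decoder networks themselves are omitted.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Turn integer class labels into one-hot rows, e.g. 3 -> [0,0,0,1,0,0,0,0,0,0]."""
    return np.eye(num_classes)[np.asarray(labels)]

def encoder_input(x, y_onehot):
    """Condition the encoder Q(z|X, y): append the label to each flattened image."""
    return np.concatenate([x, y_onehot], axis=-1)

def decoder_input(z, y_onehot):
    """Condition the decoder P(X|z, y): append the same label to the latent code."""
    return np.concatenate([z, y_onehot], axis=-1)

y = one_hot([3])
print(y)                                              # [[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]
x = np.zeros((1, 784))                                # placeholder flattened image
print(encoder_input(x, y).shape)                      # (1, 794)
z = np.random.default_rng(0).standard_normal((1, 2))  # z ~ N(0, I) at generation time
print(decoder_input(z, y).shape)                      # (1, 12)
```

At generation time the encoder is not needed: one samples $z \sim N(0,I)$, fixes the label, and passes the concatenated vector through the trained decoder.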


If you want a deeper understanding of VAEs and some hands-on experience, here are some resources you can check out: