Generative Models 2 - Variational Methods


In the previous post we saw how to define $x=G_{\theta}\left(z\right)$ such that $p\left(x;\theta\right)$ is tractable (can be calculated easily). But what if we wanted to use more general decoders? For instance, we believe that the latent space is in a lower dimension, but that the decoder is not linear. What do we do in that case?

If we want to calculate $\log p_{\theta}\left(x\right)$ in general models, we’re going to have to start approximating some things. After all, calculating:

\[\begin{equation} \log p_{\theta}\left(x\right)=\log\intop p_{\theta}\left(x\vert z\right)p\left(z\right)dz \end{equation}\]

is going to be impossible in almost every scenario.

Variational methods, the focus of this section, attempt to lower-bound the log-likelihood instead of calculating it exactly. Recall, our goal is to maximize the expected log-likelihood:

\[\begin{equation} \text{goal:}\;\text{find }\arg\max_{\theta}\mathbb{E}_{p_{\text{data}}}\left[\log p\left(x;\theta\right)\right] \end{equation}\]

If we can find a lower bound $\log p\left(x;\theta\right)\ge\phi\left(x;\theta\right)$ and maximize $\phi\left(x;\theta\right)$, then we know that the log-likelihood of our model will be at least as high as the value of the bound we reached.


Evidence Lower Bound (ELBO)

Finding a lower bound for the log-likelihood sounds kind of difficult, but follows a set of simple steps. To find the lower bound, we will need a guess for $p_{\theta}\left(z\vert x\right)$. Let’s call this guess $q_{\phi}\left(z\vert x\right)$. The log-likelihood can now be written as:

\[\begin{align} \log p_{\theta}\left(x\right) & =\log\intop p_{\theta}\left(x,z\right)dz\\ & =\log\intop\frac{p_{\theta}\left(x,z\right)}{q_{\phi}\left(z\vert x\right)}q_{\phi}\left(z\vert x\right)dz\\ & =\log\mathbb{E}_{q_{\phi}}\left[\frac{p_{\theta}\left(x,z\right)}{q_{\phi}\left(z\vert x\right)}\right]\\ & \ge\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}\left(x\vert z\right)+\log\frac{p\left(z\right)}{q_{\phi}\left(z\vert x\right)}\right]\\ & =\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}\left(x\vert z\right)\right]-D_{\text{KL}}\left(q_{\phi}\left(z\vert x\right)\vert \vert p\left(z\right)\right)\\ & =-F\left(x;\theta,\phi\right)\label{eq:free-energy} \end{align}\]

The inequality step follows from Jensen’s inequality, since the logarithm is concave. Notice how this is true for any $q_{\phi}\left(z\vert x\right)$ - we now have a lower bound on the log-likelihood! The value $F\left(x;\theta,\phi\right)$ is sometimes called the free energy for reasons that are broadly trivia and specifically way too complicated to actually talk about. At any rate, it’s useful to know that it’s called the free energy, so you won’t be surprised if you see the name again.

When is this lower bound tight? If we happen to choose $q_{\phi}\left(z\vert x\right)=p_{\theta}\left(z\vert x\right)$, then:

\[\begin{align} \log p_{\theta}\left(x\right) & \ge\mathbb{E}_{q_{\phi}\left(z\vert x\right)}\left[\log\frac{p_{\theta}\left(z\vert x\right)}{q_{\phi}\left(z\vert x\right)}+\log p_{\theta}\left(x\right)\right]\\ & =\mathbb{E}_{p_{\theta}\left(z\vert x\right)}\left[\log\frac{p_{\theta}\left(z\vert x\right)}{p_{\theta}\left(z\vert x\right)}+\log p_{\theta}\left(x\right)\right]\\ & =\mathbb{E}_{p_{\theta}\left(z\vert x\right)}\left[\log p_{\theta}\left(x\right)\right]=\log p_{\theta}\left(x\right) \end{align}\]

So, the best guess we can have for $q_{\phi}\left(z\vert x\right)$ is the conditional distribution we are approximating to begin with, $p_{\theta}\left(z\vert x\right)$.
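As a quick sanity check, here is a minimal sketch on a one-dimensional linear-Gaussian toy model, $z\sim\mathcal{N}\left(0,1\right)$ and $x\vert z\sim\mathcal{N}\left(\mu+wz,\varphi^{2}\right)$, where every quantity is available in closed form. The function names, the argument `noise_var` (standing in for $\varphi^{2}$) and all the numbers are made up for illustration.

```python
import math

def log_normal(x, mean, var):
    """Log-density of a 1-D Gaussian N(x | mean, var)."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def elbo(x, m, s2, mu=0.5, w=2.0, noise_var=0.3):
    """Lower bound for z ~ N(0,1), x|z ~ N(mu + w z, noise_var), q(z|x) = N(m, s2)."""
    # E_q[log p(x|z)]: the expected squared residual is (x - mu - w m)^2 + w^2 s2.
    expected_sq_resid = (x - mu - w * m) ** 2 + w ** 2 * s2
    expected_loglik = -0.5 * (math.log(2 * math.pi * noise_var)
                              + expected_sq_resid / noise_var)
    # KL( N(m, s2) || N(0, 1) )
    kl = 0.5 * (s2 + m ** 2 - 1.0 - math.log(s2))
    return expected_loglik - kl

def log_evidence(x, mu=0.5, w=2.0, noise_var=0.3):
    """Exact log p(x): marginally, x ~ N(mu, w^2 + noise_var)."""
    return log_normal(x, mu, w ** 2 + noise_var)

def true_posterior(x, mu=0.5, w=2.0, noise_var=0.3):
    """Mean and variance of the exact Gaussian posterior p(z|x)."""
    var = 1.0 / (1.0 + w ** 2 / noise_var)
    mean = var * w * (x - mu) / noise_var
    return mean, var

x = 1.7
m_star, s2_star = true_posterior(x)
gap_arbitrary = log_evidence(x) - elbo(x, m=0.0, s2=1.0)          # q = prior
gap_posterior = log_evidence(x) - elbo(x, m=m_star, s2=s2_star)   # q = posterior
print(f"gap with q = prior:     {gap_arbitrary:.4f}")
print(f"gap with q = posterior: {gap_posterior:.2e}")
```

With an arbitrary guess (here, the prior) the bound sits strictly below $\log p_{\theta}\left(x\right)$, while plugging in the exact posterior closes the gap up to floating-point error.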

In practice, we will assume a simple form for $q_{\phi}\left(z\vert x\right)$, almost always a Gaussian distribution. The quality of the lower bound will then depend on how far the guess $q_{\phi}\left(z\vert x\right)$ is from the true posterior $p_{\theta}\left(z\vert x\right)$.
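To make that statement precise, we can subtract the bound from the log-likelihood. Since $\log p_{\theta}\left(x\right)$ does not depend on $z$, it can be moved inside the expectation over $q_{\phi}$:

\[\begin{align} \log p_{\theta}\left(x\right)-\left(-F\left(x;\theta,\phi\right)\right) & =\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}\left(x\right)\right]-\mathbb{E}_{q_{\phi}}\left[\log\frac{p_{\theta}\left(x,z\right)}{q_{\phi}\left(z\vert x\right)}\right]\\ & =\mathbb{E}_{q_{\phi}}\left[\log\frac{q_{\phi}\left(z\vert x\right)}{p_{\theta}\left(z\vert x\right)}\right]\\ & =D_{\text{KL}}\left(q_{\phi}\left(z\vert x\right)\vert \vert p_{\theta}\left(z\vert x\right)\right)\ge0 \end{align}\]

The second line uses $p_{\theta}\left(x,z\right)=p_{\theta}\left(z\vert x\right)p_{\theta}\left(x\right)$. So the gap is exactly the KL-divergence between the guess and the true posterior, and it vanishes only when the two coincide.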


Classical Variational Inference (VI)

What we wrote above in equation \eqref{eq:free-energy} is a lower bound for the log-likelihood of a single data point. Remember, what we actually want to maximize is the log-likelihood of all training examples. The “traditional” way to use this lower bound for training a model is to alternate the following steps:

\[\begin{align} \left(I\right)\qquad & \text{for }i=1,\cdots,N:\quad\hat{\phi}_{i}=\arg\max_{\phi}-F\left(x_{i};\theta^{\left(t-1\right)},\phi\right)\\ \left(II\right)\qquad & \theta^{\left(t\right)}=\arg\max_{\theta}\left\{ -\frac{1}{N}\sum_{i=1}^{N}F\left(x_{i};\theta,\hat{\phi}_{i}\right)\right\} \end{align}\]

The first step gives us a better lower bound for the true log-likelihood. The second step is where we make a better prediction of the parameters $\theta$ under the specific lower bound.

Notice that step $\left(I\right)$ basically requires an optimization of a distribution for every observed data point $x_{i}$. This is quite expensive. However, as a byproduct we get, for the same price, the approximation $q\left(z\vert x_{i};\hat{\phi}_{i}\right)$ for each posterior $p\left(z\vert x_{i};\hat{\theta}\right)$. This can sometimes be useful (maybe more on that later on).

Example

I’ve kept everything pretty vague until now, so an example might be good right about now.

Let’s, for a moment, imagine that we don’t know how to find the posterior distribution for the pPCA. Our model is:

\[\begin{equation} p_{\theta}\left(x,z\right)=\mathcal{N}\left(z\vert \;0,I\right)\times\mathcal{N}\left(x\vert \;\mu+Wz,\;I\varphi^{2}\right) \end{equation}\]

Our parameters are $\theta=\left\{ \mu,W,\varphi\right\}$. We are assuming that we don’t know the posterior distribution, so let’s approximate it with some other distribution. Often the simplest distribution to use is an isotropic Gaussian, so let’s define:

\[\begin{equation} \phi=\left\{ m,\sigma^{2}\right\} \qquad q_{\phi}\left(z\vert x\right)=\mathcal{N}\left(z\vert \:m,I\sigma^{2}\right)\approx p_{\theta}\left(z\vert x\right) \end{equation}\]

Now, for each data point $x_{i}$, we will try to optimize $\phi_{i}$ so the following is as high as possible:

\[\begin{align} \hat{\phi}\left(x_{i}\right) & =\arg\max_{\phi}-F\left(x_{i};\theta,\phi\right)\\ & =\arg\max_{\phi}\left\{ \mathbb{E}_{q_{\phi}}\left[\log p_{\theta}\left(x_{i}\vert z\right)\right]- D_{\text{KL}}\left(q_{\phi}\left(z\vert x_{i}\right)\vert \vert p\left(z\right)\right)\right\} \end{align}\]

This seems quite difficult to calculate. Luckily for us, both $p\left(z\right)$ and $q_{\phi}\left(z\vert x_{i}\right)$ are Gaussian distributions, so the KL-divergence between the two has a closed-form expression:

\[\begin{align} D_{\text{KL}}\left(q_{\phi}\left(z\vert x_{i}\right)\vert \vert p\left(z\right)\right) & =\frac{1}{2}\left(\text{trace}\left(I\sigma^{2}\right)-\text{dim}\left(z\right)+\| m\| ^{2}-\log\left\vert I\sigma^{2}\right\vert \right)\\ & =\frac{\text{dim}\left(z\right)}{2}\left(\sigma^{2}-\log\sigma^{2}-1\right)+\frac{1}{2}\| m\| ^{2} \end{align}\]
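The closed-form expression is easy to check numerically. The sketch below evaluates the trace/dimension/norm/log-determinant form of the KL-divergence and compares it with a brute-force average of $\log q_{\phi}\left(z\right)-\log p\left(z\right)$ over samples from $q_{\phi}$. The function names and the values of $m$ and $\sigma^{2}$ are arbitrary choices for illustration.

```python
import math
import random

def kl_closed_form(m, sigma2):
    """KL( N(m, sigma2 * I) || N(0, I) ) from the closed-form expression."""
    d = len(m)
    sq_norm = sum(mi ** 2 for mi in m)
    # trace(I sigma2) = d * sigma2 and log|I sigma2| = d * log(sigma2)
    return 0.5 * (d * sigma2 - d + sq_norm - d * math.log(sigma2))

def kl_monte_carlo(m, sigma2, num_samples=100_000, seed=0):
    """Estimate the same KL as the average of log q(z) - log p(z) for z ~ q."""
    rng = random.Random(seed)
    sigma = math.sqrt(sigma2)
    d = len(m)
    total = 0.0
    for _ in range(num_samples):
        z = [mi + sigma * rng.gauss(0.0, 1.0) for mi in m]
        # The (2*pi)^(-d/2) factors of q and p cancel in the difference.
        log_q = (-0.5 * d * math.log(sigma2)
                 - 0.5 * sum((zi - mi) ** 2 for zi, mi in zip(z, m)) / sigma2)
        log_p = -0.5 * sum(zi ** 2 for zi in z)
        total += log_q - log_p
    return total / num_samples

m, sigma2 = [1.0, -0.5], 0.5
exact = kl_closed_form(m, sigma2)
estimate = kl_monte_carlo(m, sigma2)
print(f"closed form: {exact:.4f}   Monte Carlo: {estimate:.4f}")
```

The two numbers should agree up to Monte Carlo noise, and the closed form is exactly zero when $m=0$ and $\sigma^{2}=1$, i.e. when $q_{\phi}$ equals the prior.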

So that’s one part of the lower bound we can directly calculate. How about the expectation? Well, in this case there is also a closed-form expression for the expectation, but many times there won’t be. Instead, we can try to approximate $\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}\left(x_{i}\vert z\right)\right]$ using, you guessed it, Monte Carlo (MC) samples. Basically, we will draw $M$ samples $z_{j}\sim q_{\phi}\left(z\vert x_{i}\right)$ and approximate the expectation using these samples:

\[\begin{equation} \mathbb{E}_{q_{\phi}}\left[\log p_{\theta}\left(x_{i}\vert z\right)\right]\approx\frac{1}{M}\sum_{j:\,z_{j}\sim q_{\phi}\left(z\vert x_{i}\right)}^{M}\log p_{\theta}\left(x_{i}\vert z_{j}\right) \end{equation}\]
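Here is a minimal sketch of that MC approximation for a scalar version of the model ($z$ and $x$ both one-dimensional), where the exact expectation is known and can serve as a reference. The parameter values, `noise_var` (standing in for $\varphi^{2}$) and the function names are assumptions made for the sake of the example.

```python
import math
import random

def log_likelihood(x, z, mu, w, noise_var):
    """log p(x | z) for the scalar model x | z ~ N(mu + w z, noise_var)."""
    resid = x - mu - w * z
    return -0.5 * (math.log(2 * math.pi * noise_var) + resid ** 2 / noise_var)

def expected_loglik_mc(x, m, sigma2, num_samples,
                       mu=0.5, w=2.0, noise_var=0.3, seed=0):
    """Average log p(x | z_j) over num_samples draws z_j ~ N(m, sigma2)."""
    rng = random.Random(seed)
    sigma = math.sqrt(sigma2)
    draws = (m + sigma * rng.gauss(0.0, 1.0) for _ in range(num_samples))
    return sum(log_likelihood(x, z, mu, w, noise_var) for z in draws) / num_samples

def expected_loglik_exact(x, m, sigma2, mu=0.5, w=2.0, noise_var=0.3):
    """Closed form: E_q[(x - mu - w z)^2] = (x - mu - w m)^2 + w^2 sigma2."""
    mean_sq_resid = (x - mu - w * m) ** 2 + w ** 2 * sigma2
    return -0.5 * (math.log(2 * math.pi * noise_var) + mean_sq_resid / noise_var)

x_i, m, sigma2 = 1.7, 0.5, 0.1
exact = expected_loglik_exact(x_i, m, sigma2)
for M in (1, 10, 1000, 50_000):
    estimate = expected_loglik_mc(x_i, m, sigma2, M)
    print(f"M = {M:>6}: MC = {estimate:.4f}   exact = {exact:.4f}")
```

The estimate is noisy for small $M$ and settles around the exact value as $M$ grows, at the usual $1/\sqrt{M}$ rate.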

Putting all of that together, for each data point $x_{i}$ we will try to find $\hat{\phi}_{i}=\left\{ m_{i},\sigma_{i}^{2}\right\}$ that maximizes the following:

\[\begin{align} \hat{\phi}_{i} & =\arg\max_{\phi}\left\{ \frac{1}{M}\sum_{j:\,z_{j}\sim q_{\phi}\left(z\vert x_{i}\right)}^{M}\log p_{\theta}\left(x_{i}\vert z_{j}\right)\right.\\ & \qquad\left.-\frac{\text{dim}\left(z\right)}{2}\left(\sigma^{2}-\log\sigma^{2}-1\right)-\frac{1}{2}\| m\| ^{2}\right\} \end{align}\]

If we do this for each $x_{i}$, then our lower-bound for the expected log-likelihood of the whole data will be:

\[\begin{align} \frac{1}{N}\sum_{i=1}^{N}\log p_{\theta}\left(x_{i}\right) & \ge\frac{1}{N}\sum_{i=1}^{N}\left[\frac{1}{M}\sum_{j:\,z_{j}\sim q_{\hat{\phi}_{i}}\left(z\vert x_{i}\right)}^{M}\log p_{\theta}\left(x_{i}\vert z_{j}\right)\right.\\ & \qquad\left.-\frac{\text{dim}\left(z\right)}{2}\left(\sigma_{i}^{2}-\log\sigma_{i}^{2}-1\right)-\frac{1}{2}\| m_{i}\| ^{2}\right] \end{align}\]
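To see what the per-data-point optimization looks like in code, here is a rough sketch for a single $x_{i}$ in a scalar version of the model. Since the expectation has a closed form in this case, the sketch uses it instead of MC samples and runs plain gradient ascent on $\left(m,\log\sigma^{2}\right)$. The step size, number of steps, parameter values and function names are all arbitrary choices, not a recommended recipe.

```python
import math

def elbo_gradients(x, m, log_s2, mu, w, noise_var):
    """Gradient of the closed-form bound w.r.t. m and log(sigma^2), scalar z."""
    s2 = math.exp(log_s2)
    grad_m = w * (x - mu - w * m) / noise_var - m
    # d/d(log s2) = s2 * d/d(s2), by the chain rule.
    grad_log_s2 = 0.5 - 0.5 * s2 * (w ** 2 / noise_var + 1.0)
    return grad_m, grad_log_s2

def fit_variational_params(x, mu=0.5, w=2.0, noise_var=0.3, lr=0.05, steps=2000):
    """Gradient ascent on phi = (m, sigma^2) for one data point, with theta fixed."""
    m, log_s2 = 0.0, 0.0  # start from q = prior
    for _ in range(steps):
        g_m, g_ls = elbo_gradients(x, m, log_s2, mu, w, noise_var)
        m += lr * g_m
        log_s2 += lr * g_ls
    return m, math.exp(log_s2)

x_i = 1.7
m_hat, s2_hat = fit_variational_params(x_i)

# Exact posterior of the scalar model, for comparison only.
post_var = 1.0 / (1.0 + 2.0 ** 2 / 0.3)
post_mean = post_var * 2.0 * (x_i - 0.5) / 0.3
print(f"m_hat  = {m_hat:.6f}   posterior mean = {post_mean:.6f}")
print(f"s2_hat = {s2_hat:.6f}   posterior var  = {post_var:.6f}")
```

In this scalar toy case the Gaussian family contains the exact posterior, so the optimized $\hat{\phi}_{i}$ recovers it. The point is mostly the cost: this whole loop has to be run again for every single data point.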

If that seems to you like a really roundabout way to get a lower bound for the log-likelihood then, well, I don’t blame you.


Variational Auto-Encoders (VAEs)

Variational auto-encoders (VAEs) attempt to do the above in a way that is slightly more efficient. Instead of optimizing the parameters $\phi$ for each data point $x_{i}$ individually, an encoder is trained to try and predict them. At the same time, the mapping from $\mathcal{Z}$ to $\mathcal{X}$ is also trained.

Concretely, suppose $z\sim\mathcal{N}\left(0,I\right)$. Because there’s a closed-form expression for the KL-divergence between two Gaussians, and because it’s easy to sample from a Gaussian, it will be convenient if we assume that:

\[\begin{equation} q_{\phi}\left(z\vert x\right)=\mathcal{N}\left(z\vert \:\mu_{\phi}\left(x\right),\Sigma_{\phi}\left(x\right)\right) \end{equation}\]

In other words, our guess for the posterior $p_{\theta}\left(z\vert x\right)$ is a Gaussian distribution whose mean and covariance are functions of the observed data, $x$. In practice, a neural network (NN) will be used to encode $x$ into the mean $\mu_{\phi}\left(x\right)$ and covariance $\Sigma_{\phi}\left(x\right)$. Now, instead of finding $\hat{\phi}_{i}$ as we did in VI before, we’ll just try to train the encoder $\mu_{\phi}\left(x_{i}\right)$ and $\Sigma_{\phi}\left(x_{i}\right)$ to give a good guess for the posterior.

That is, we’ll train both the encoder and the decoder, $G_{\theta}\left(z\right)$, at the same time using the usual variational loss we saw in equation \eqref{eq:free-energy}:

\[\begin{equation} \left\{ \hat{\theta},\hat{\phi}\right\} =\arg\max_{\theta,\phi}\left\{ \frac{1}{N}\sum_{i=1}^{N}\left[\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}\left(x_{i}\vert G_{\theta}\left(z\right)\right)\right]-D_{\text{KL}}\left(q_{\phi}\left(z\vert x_{i}\right)\vert \vert p\left(z\right)\right)\right]\right\} \end{equation}\]

As we mentioned earlier, we know how to calculate $D_{\text{KL}}\left(q_{\phi}\left(z\vert x_{i}\right)\vert \vert p\left(z\right)\right)$ exactly:

\[\begin{equation} D_{\text{KL}}\left(q_{\phi}\left(z\vert x_{i}\right)\vert \vert p\left(z\right)\right)=\frac{1}{2}\left(\text{trace}\left[\Sigma_{\phi}\left(x_{i}\right)\right]+\| \mu_{\phi}\left(x_{i}\right)\| ^{2}-\log\left\vert \Sigma_{\phi}\left(x_{i}\right)\right\vert -\text{dim}\left(z\right)\right) \end{equation}\]
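A common simplification, assumed here rather than taken from the text above, is to let the encoder output a diagonal covariance $\Sigma_{\phi}\left(x_{i}\right)=\text{diag}\left(\sigma_{1}^{2},\ldots,\sigma_{d}^{2}\right)$ with $d=\text{dim}\left(z\right)$. The trace and the log-determinant then reduce to sums over coordinates:

\[\begin{equation} D_{\text{KL}}\left(q_{\phi}\left(z\vert x_{i}\right)\vert \vert p\left(z\right)\right)=\frac{1}{2}\sum_{k=1}^{d}\left(\sigma_{k}^{2}+\mu_{k}^{2}-\log\sigma_{k}^{2}-1\right) \end{equation}\]

where $\mu_{k}$ and $\sigma_{k}^{2}$ are the $k$-th entries of $\mu_{\phi}\left(x_{i}\right)$ and of the diagonal of $\Sigma_{\phi}\left(x_{i}\right)$. Setting every $\sigma_{k}^{2}$ to the same value $\sigma^{2}$ gives back the isotropic case from the earlier example.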

which will be pretty simple to calculate on the fly.

But we are missing other ingredients. For instance, how should we define the observation model $p_{\theta}\left(x\vert G_{\theta}\left(z\right)\right)$? And how do we calculate (and back-propagate through) the expectation of the first term?

In a second we’ll get to what people usually use for the observation model, but for now let’s leave it up to the user. Having defined $p_{\theta}\left(x\vert G_{\theta}\left(z\right)\right)$, we’re still left with the question of how to calculate the expectation. What is usually done is the most straightforward thing - simply use an MC approximation. In fact, the expectation is usually approximated with a single sample from $q_{\phi}\left(z\vert x_{i}\right)$:

\[\begin{equation} \mathbb{E}_{q_{\phi}}\left[\log p_{\theta}\left(x_{i}\vert G_{\theta}\left(z\right)\right)\right]\approx\log p_{\theta}\left(x_{i}\vert G_{\theta}\left(\tilde{z}\right)\right)\quad\tilde{z}\sim\mathcal{N}\left(\mu_{\phi}\left(x_{i}\right),\Sigma_{\phi}\left(x_{i}\right)\right) \end{equation}\]

While this sounds kind of ridiculous, it actually works okay.
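The text above does not spell out how gradients get through the sampling step. The standard device in the VAE literature is the reparameterization trick: write the sample as $\tilde{z}=\mu_{\phi}\left(x_{i}\right)+\sigma_{\phi}\left(x_{i}\right)\epsilon$ with $\epsilon\sim\mathcal{N}\left(0,1\right)$, so that $\tilde{z}$ is a deterministic, differentiable function of the encoder outputs. The sketch below illustrates this on a scalar toy model with a linear decoder, where the exact gradients are known. The model, its parameter values and the function names are assumptions for illustration only.

```python
import math
import random

def dloglik_dz(x, z, mu=0.5, w=2.0, noise_var=0.3):
    """Derivative of log p(x|z) w.r.t. z for x | z ~ N(mu + w z, noise_var)."""
    return w * (x - mu - w * z) / noise_var

def reparam_gradients(x, m, sigma, num_samples=200_000, seed=0):
    """Average of single-sample pathwise gradients w.r.t. (m, sigma)."""
    rng = random.Random(seed)
    grad_m = grad_sigma = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)   # noise that does not depend on (m, sigma)
        z = m + sigma * eps         # z ~ N(m, sigma^2), differentiable in (m, sigma)
        g = dloglik_dz(x, z)
        grad_m += g                 # chain rule: dz/dm = 1
        grad_sigma += g * eps       # chain rule: dz/dsigma = eps
    return grad_m / num_samples, grad_sigma / num_samples

def exact_gradients(x, m, sigma, mu=0.5, w=2.0, noise_var=0.3):
    """Gradients of the closed-form E_q[log p(x|z)] w.r.t. (m, sigma)."""
    return w * (x - mu - w * m) / noise_var, -w ** 2 * sigma / noise_var

x_i, m, sigma = 1.7, 0.5, math.sqrt(0.1)
estimated = reparam_gradients(x_i, m, sigma)
exact = exact_gradients(x_i, m, sigma)
print("averaged single-sample gradients:", estimated)
print("exact gradients:                 ", exact)
```

Each single-sample gradient is noisy, but its average matches the exact gradient, which is why stochastic gradient training with one sample per data point can still make progress.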

Observation Model

The most common observation model is, as you probably already guessed, simply a Gaussian distribution:

\[\begin{equation} \log p_{\theta}\left(x\vert G_{\theta}\left(z\right)\right)=-\frac{\beta}{2}\| x-G_{\theta}\left(z\right)\| ^{2}+\frac{\text{dim}\left(x\right)}{2}\log\beta+\text{const} \end{equation}\]

where $1/\beta$ is the variance of the observation model. In vanilla VAEs, $\beta$ is almost ubiquitously set to 1. If we think about the whole story so far, this is like assuming that every example has been observed with Gaussian noise whose variance is 1. This might make sense sometimes, but in images (for instance) where pixels take values between 0 and 1, maybe it doesn’t make sense? Just something to think about.

Anyway, setting $\beta=1$, we get the “standard” VAE loss (ignoring constants and multiplicative factors, as is usually done):

\[\begin{equation} L\left(\theta,\phi\right)=\underbrace{\| x_{i}-G_{\theta}\left(\tilde{z}\right)\| ^{2}}_{\text{reconstruction}}+\underbrace{\text{trace}\left[\Sigma_{\phi}\left(x_{i}\right)\right]+\| \mu_{\phi}\left(x_{i}\right)\| ^{2}-\log\left\vert \Sigma_{\phi}\left(x_{i}\right)\right\vert }_{\text{KL term}} \end{equation}\]

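If we use a different value for $\beta$, we recover something called $\beta$-VAEs (it’s kind of silly that it has a different name). The $\beta$-VAE literature usually puts the weight on the KL term instead, which is the same loss up to an overall rescaling, with a weight of $1/\beta$ in the notation used here: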

\[\begin{equation} L_{\beta}\left(\theta,\phi\right)=\beta\cdot\| x_{i}-G_{\theta}\left(\tilde{z}\right)\| ^{2}+\text{trace}\left[\Sigma_{\phi}\left(x_{i}\right)\right]+\| \mu_{\phi}\left(x_{i}\right)\| ^{2}-\log\left\vert \Sigma_{\phi}\left(x_{i}\right)\right\vert \end{equation}\]
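For concreteness, this loss can be sketched for a single data point in a few lines. The sketch assumes a diagonal $\Sigma_{\phi}\left(x_{i}\right)$ parameterized by log-variances and takes the encoder outputs and the decoder as given; `vae_loss`, `toy_decoder` and all the numbers are hypothetical stand-ins, and a real implementation would use an autodiff framework with neural networks for both maps.

```python
import math
import random

def vae_loss(x, mu, log_var, decoder, beta=1.0, rng=random):
    """Single-sample loss L_beta for one data point with diagonal Sigma_phi(x).

    x        -- observed data point (list of floats)
    mu       -- encoder mean mu_phi(x) (list of floats)
    log_var  -- log of the diagonal of Sigma_phi(x) (list of floats)
    decoder  -- callable mapping a latent list z to a reconstruction of x
    """
    # One sample z = mu + sigma * eps from q(z|x).
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]
    x_hat = decoder(z)
    reconstruction = sum((xi - xh) ** 2 for xi, xh in zip(x, x_hat))
    # trace(Sigma) + ||mu||^2 - log|Sigma| for a diagonal Sigma.
    kl_term = sum(math.exp(lv) + m ** 2 - lv for m, lv in zip(mu, log_var))
    return beta * reconstruction + kl_term

def toy_decoder(z):
    """Hypothetical linear decoder from a 2-D latent to a 3-D observation."""
    return [z[0], z[1], z[0] + z[1]]

x_i = [0.2, -0.1, 0.4]
mu_i, log_var_i = [0.1, 0.0], [-2.0, -2.0]
print(vae_loss(x_i, mu_i, log_var_i, toy_decoder, beta=1.0))
print(vae_loss(x_i, mu_i, log_var_i, toy_decoder, beta=10.0))
```

As in the formulas, the additive constant and the factor of $1/2$ are dropped, so the returned value is not itself a bound on the log-likelihood, only a training objective with the same minimizers.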


Conclusion

VAEs, as described in this post, are very general. We didn’t impose any real conditions on the decoder (or the encoder), and in theory they should work quite well. In practice, however, VAEs (when used in vision) tend to produce very blurry images. For a long time, the observation model was blamed for these subpar results (I'm happy those concerns have faded away after the success of diffusion models), but it’s not really clear why they underperform in the task of generation. As a consequence of this bad performance, many extensions to the vanilla VAE were proposed. These include using a more expressive variational distribution $q_{\phi}\left(z\vert x\right)$, adding more MCMC samples with appropriate weights in the ELBO calculation, and changing the observation model. None of these seem to help in a meaningful way, as far as I know.

While VAEs are kind of unpopular at the moment, they are still a very good introduction to generative models. Moreover, they are still used to some extent - the most ubiquitous use is as auto-encoders for other generative models. The fact that they inherently give a lower bound for the likelihood of newly observed samples is also a major plus.

