1. Latent Variable Models

Probabilistic models describe a probability distribution. For example, we may wish to model the distribution \(p(x)\) from which a set of points \(\{x_1,\ldots,x_N\}\) was drawn. Or, if we have both the inputs \(x\) and the targets \(y\), we may want to model the conditional distribution \(p(y \vert x)\). For now, we will only concern ourselves with the former case of modelling \(p(x)\).

Latent variable models assume that the distribution \(p(x)\) can be constructed by applying a complex non-linear transformation to a relatively simple (e.g. Gaussian) distribution \(p(z)\). We therefore model \(p(x \vert z)\) and use it to find \(p(x)\) by marginalizing over \(z\):

$$ p(x) = \int p(x \vert z)p(z)dz $$

One example of \(p(x \vert z)\) is \(\mathcal{N}(\mu_{nn}(z),\sigma_{nn}(z))\) where both \(\mu_{nn}\) and \(\sigma_{nn}\) are some neural networks with inputs \(z\). \(z\) itself can be drawn from some Gaussian distribution with fixed parameters.
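As a rough illustration of this generative process, the sketch below draws \(z\) from a standard Gaussian and then draws \(x\) from \(p(x \vert z)\). The functions `mu_nn` and `sigma_nn` are hypothetical closed-form stand-ins for the two neural networks, and the second argument of the Gaussian is treated as a standard deviation; both choices are assumptions made only for this example.

```python
import math
import random

def mu_nn(z):
    # Hypothetical stand-in for the mean network: any non-linear function of z.
    return math.tanh(2.0 * z) + 0.5 * z

def sigma_nn(z):
    # Hypothetical stand-in for the standard-deviation network.
    # The exponential keeps the output strictly positive.
    return math.exp(-0.5 * abs(z))

def sample_x(rng):
    z = rng.gauss(0.0, 1.0)               # z ~ p(z) = N(0, 1)
    x = rng.gauss(mu_nn(z), sigma_nn(z))  # x ~ p(x | z)
    return z, x

rng = random.Random(0)
samples = [sample_x(rng) for _ in range(5)]  # five (z, x) pairs
```

Even though both \(p(z)\) and \(p(x \vert z)\) are Gaussian, the resulting marginal over \(x\) is generally not Gaussian, because the mean and spread of \(x\) depend non-linearly on \(z\).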

Let \(\theta\) denote the parameters of the model for \(p(x \vert z)\). Given a dataset \(\{x^{(1)},\ldots,x^{(N)}\}\), our goal is to find \(\theta\) such that:

$$ \begin{align} \theta &\leftarrow \text{argmax}_{\theta} \frac{1}{N}\sum_{i=1}^N \log p_\theta(x^{(i)})\\ &\leftarrow \text{argmax}_{\theta} \frac{1}{N}\sum_{i=1}^N \log \int p_\theta(x^{(i)} \vert z)p(z)dz \end{align} $$

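The issue, however, is that the integral in the equation above is generally intractable: when \(p_\theta(x \vert z)\) is parameterized by a neural network, the integral has no closed form. The remainder of this set of lecture notes is dedicated to finding a tractable approximation to this objective.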
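To make the integral concrete, here is a minimal sketch that estimates \(p(x)\) by averaging \(p(x \vert z_j)\) over samples \(z_j \sim p(z)\). It assumes a toy model chosen for this example only, \(p(z) = \mathcal{N}(0,1)\) and \(p(x \vert z) = \mathcal{N}(z, 1)\), for which the marginal happens to be known exactly: \(p(x) = \mathcal{N}(0, 2)\).

```python
import math
import random

def normal_pdf(x, mean, std):
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

def marginal_mc(x, num_samples, rng):
    # p(x) is approximated by (1/M) * sum_j p(x | z_j), with z_j ~ p(z) = N(0, 1)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)
        total += normal_pdf(x, z, 1.0)  # toy likelihood p(x | z) = N(z, 1)
    return total / num_samples

rng = random.Random(0)
x = 1.5
estimate = marginal_mc(x, 20000, rng)
exact = normal_pdf(x, 0.0, math.sqrt(2.0))  # closed form for this toy model only
```

This works here because \(z\) is one-dimensional and the prior overlaps well with the posterior. With a high-dimensional \(z\) and a neural-network likelihood, most prior samples contribute almost nothing to the sum, so the naive estimator would need an impractical number of samples.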

2. Digression

Before we proceed further, let us talk about two important concepts: entropy and KL divergence.

A. Entropy

The entropy \(\mathcal{H}\) of a probability distribution \(p\) of some random variable \(X\) that takes on values \(x\) is given as:

$$ \mathcal{H}(p) = -\int p(x)\log p(x) dx = -\mathbb{E}_{x \sim p(x)}\left[\log p(x) \right] $$

Entropy is a measure of the randomness in the distribution. If the entire probability mass (of \(1\)) were assigned to a single value of \(X\), say \(x'\), the entropy would be zero, since at \(x'\) we have \(p(x')\log p(x')\) \(=\)\(1\cdot \log 1 = 0\) and at \(x\neq x'\) we have \(0\cdot \log 0 = 0\) (the latter is understood as the limit of \(p\log p\) as \(p \to 0\), which can be shown with L'Hôpital's rule).

As an example, the wider a Gaussian distribution is, the higher is its entropy.
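The Gaussian case can be checked numerically. The sketch below uses the closed form \(\frac{1}{2}\ln(2\pi e\sigma^2)\) for the entropy of a Gaussian with standard deviation \(\sigma\), and compares it with a Monte Carlo average of \(-\log p(x)\). The function names are illustrative only.

```python
import math
import random

def gaussian_entropy(std):
    # Closed-form entropy of N(mean, std^2); it does not depend on the mean.
    return 0.5 * math.log(2.0 * math.pi * math.e * std ** 2)

def gaussian_entropy_mc(std, num_samples, rng):
    # H(p) = -E_{x ~ p}[log p(x)], estimated by averaging -log p(x) over samples.
    total = 0.0
    for _ in range(num_samples):
        x = rng.gauss(0.0, std)
        log_p = -0.5 * math.log(2.0 * math.pi * std ** 2) - 0.5 * (x / std) ** 2
        total -= log_p
    return total / num_samples

rng = random.Random(0)
narrow = gaussian_entropy(0.5)
wide = gaussian_entropy(2.0)
wide_mc = gaussian_entropy_mc(2.0, 20000, rng)  # should be close to `wide`
```

Here `narrow` is smaller than `wide`, matching the statement that a wider Gaussian has higher entropy.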

B. KL Divergence

The KL divergence between two distributions \(p\) and \(q\) of some random variable \(X\) that takes on values \(x\) is given as:

$$ \begin{align} D_{KL}(p \vert\vert q) &= \mathbb{E}_{x \sim p(x)}\left[\log\frac{p(x)}{q(x)}\right]\\ &= \mathbb{E}_{x \sim p(x)}\left[\log p(x)\right] - \mathbb{E}_{x \sim p(x)}\left[\log q(x)\right]\\ &= -\mathcal{H}(p) - \mathbb{E}_{x \sim p(x)}\left[\log q(x)\right] \end{align} $$

Suppose that we want to choose \(p\) so as to minimize its KL divergence with \(q\). Let us write the second term above as:

$$ -\mathbb{E}_{x \sim p(x)}\left[\log q(x)\right] = -\int p(x)\log q(x)dx $$

It can easily be seen that this term is minimized if we choose \(p(x)\) such that it is \(1\) when \(x =\) \(\text{argmax}_x q(x)\) and \(0\) otherwise (see footnote 1). In other words, minimizing this term has the effect of concentrating the entire probability mass of \(p\) onto the value of \(x\) that has the highest probability under \(q\).

However, to also minimize the first term, we need to maximize the entropy of \(p\). As discussed in the previous subsection, maximizing the entropy is the same as making \(p\) wider or more spread out. This phenomenon is depicted in the following figure:

One final detail about the KL divergence is that it is always non-negative:

$$ \begin{align} D_{KL}(p \vert\vert q) &= \mathbb{E}_{x \sim p(x)}\left[\log\frac{p(x)}{q(x)}\right]\\ &= \mathbb{E}_{x \sim p(x)}\left[-\log\frac{q(x)}{p(x)}\right]\\ &\geq -\log\mathbb{E}_{x \sim p(x)}\left[\frac{q(x)}{p(x)}\right]\\ &= -\log \int p(x)\frac{q(x)}{p(x)}dx\\ &= -\log \int q(x)dx\\ &= -\log 1\\ &= 0 \end{align} $$

where the third line follows from Jensen's inequality (see footnote 2).
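As a small numerical check, the sketch below evaluates the standard closed form of the KL divergence between two univariate Gaussians, \(\log\frac{\sigma_q}{\sigma_p} + \frac{\sigma_p^2 + (\mu_p-\mu_q)^2}{2\sigma_q^2} - \frac{1}{2}\). The specific means and standard deviations are arbitrary values chosen for illustration.

```python
import math

def kl_gaussians(mean_p, std_p, mean_q, std_q):
    # Closed form of D_KL(p || q) for p = N(mean_p, std_p^2), q = N(mean_q, std_q^2).
    return (math.log(std_q / std_p)
            + (std_p ** 2 + (mean_p - mean_q) ** 2) / (2.0 * std_q ** 2)
            - 0.5)

same = kl_gaussians(0.0, 1.0, 0.0, 1.0)     # identical distributions
forward = kl_gaussians(0.0, 1.0, 1.0, 2.0)  # D_KL(p || q)
reverse = kl_gaussians(1.0, 2.0, 0.0, 1.0)  # D_KL(q || p)
```

The divergence is zero when the two distributions coincide and positive otherwise. It is also not symmetric: `forward` and `reverse` differ, which is why the order of the arguments matters in what follows.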

3. The Variational Approximation

Let us approximate \(p(z \vert x^{(i)})\) with some distribution \(q_i(z)\). Note that we have a separate distribution for each data point \(x^{(i)}\). \(q_i\) could be as simple as a Gaussian with mean \(\mu_i\) and standard deviation \(\sigma_i\). Using Jensen's inequality (see footnote 2) we can bound \(\log p(x^{(i)})\) from below as follows:

$$ \begin{align} \log p(x^{(i)}) &= \log \int p(x^{(i)} \vert z)p(z)dz\\ &= \log \int p(x^{(i)} \vert z)p(z)\frac{q_i(z)}{q_i(z)}dz\\ &= \log \mathbb{E}_{z \sim q_i(z)}\left[\frac{p(x^{(i)} \vert z)p(z)}{q_i(z)}\right]\\ &\geq \mathbb{E}_{z \sim q_i(z)}\left[\log\frac{p(x^{(i)} \vert z)p(z)}{q_i(z)}\right]\\ &= \mathbb{E}_{z \sim q_i(z)}\left[\log p(x^{(i)} \vert z) + \log p(z)\right] - \mathbb{E}_{z \sim q_i(z)}\left[\log q_i(z)\right]\\ &= \mathbb{E}_{z \sim q_i(z)}\left[\log p(x^{(i)} \vert z) + \log p(z)\right] + \mathcal{H}(q_i)\\ \end{align} $$

Let us denote the right hand side of the inequality above with \(\mathcal{L}_i(p,q_i)\). Note that we may interpret maximizing \(\mathcal{L}_i\) in the same way as we interpreted minimizing the KL divergence in the previous section.

Let us write down the KL divergence between \(q_i(z)\) and \(p(z \vert x^{(i)})\):

$$ \begin{align} D_{KL}(q_i(z) \vert\vert p(z \vert x^{(i)})) &= \mathbb{E}_{z \sim q_i(z)}\left[\log \frac{q_i(z)}{p(z\vert x^{(i)})}\right]\\ &= \mathbb{E}_{z \sim q_i(z)}\left[\log \frac{q_i(z)p(x^{(i)})}{p(x^{(i)}\vert z)p(z)}\right]\\ &= \mathbb{E}_{z \sim q_i(z)}\left[\log q_i(z)\right]+\mathbb{E}_{z \sim q_i(z)}\left[ \log p(x^{(i)})\right]-\mathbb{E}_{z \sim q_i(z)}\left[\log p(x^{(i)}\vert z)+\log p(z)\right]\\ &=-\mathcal{H}(q_i) +\log p(x^{(i)})\mathbb{E}_{z \sim q_i(z)}[1]-\mathbb{E}_{z \sim q_i(z)}\left[\log p(x^{(i)}\vert z)+\log p(z)\right]\\ &= -\mathcal{L}_i(p,q_i)+\log p(x^{(i)}) \end{align} $$

Since the KL divergence is always non-negative, we again have that:

$$ \begin{align} \log p(x^{(i)}) &= D_{KL}(q_i(z) \vert\vert p(z \vert x^{(i)}))+\mathcal{L}_i(p,q_i)\\ &\geq \mathcal{L}_i(p,q_i) \end{align} $$
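The decomposition above can be verified on a toy model where every term has a closed form. The sketch assumes, for illustration only, \(p(z) = \mathcal{N}(0,1)\), \(p(x \vert z) = \mathcal{N}(z,1)\) and \(q_i(z) = \mathcal{N}(m, s^2)\). In this model \(p(x) = \mathcal{N}(0,2)\) and the true posterior is \(p(z \vert x) = \mathcal{N}(x/2, 1/2)\).

```python
import math

def log_marginal(x):
    # Toy model: z ~ N(0, 1), x | z ~ N(z, 1), so p(x) = N(0, 2).
    return -0.5 * math.log(2.0 * math.pi * 2.0) - x ** 2 / 4.0

def elbo(x, m, s):
    # q(z) = N(m, s^2); the three terms of L_i in closed form for the toy model.
    exp_log_lik = -0.5 * math.log(2.0 * math.pi) - 0.5 * ((x - m) ** 2 + s ** 2)
    exp_log_prior = -0.5 * math.log(2.0 * math.pi) - 0.5 * (m ** 2 + s ** 2)
    entropy = 0.5 * math.log(2.0 * math.pi * math.e * s ** 2)
    return exp_log_lik + exp_log_prior + entropy

x = 1.5
loose = elbo(x, 0.0, 1.0)                 # an arbitrary choice of q
tight = elbo(x, x / 2.0, math.sqrt(0.5))  # q equal to the true posterior
gap = log_marginal(x) - loose             # equals D_KL(q || p(z | x))
```

With an arbitrary \(q_i\) the bound is strictly below \(\log p(x)\), and the gap is exactly the KL divergence. When \(q_i\) equals the true posterior the gap closes and the bound is tight.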

Note that, since \(\log p(x^{(i)})\) does not depend on \(q_i\), maximizing \(\mathcal{L}_i\) with respect to \(q_i\) minimizes the KL divergence between \(q_i(z)\) and \(p(z \vert x^{(i)})\), which tightens the bound. Similarly, maximizing it with respect to \(p\) pushes up a lower bound on \(\log p(x^{(i)})\). Therefore, we can simply choose \(\theta\) (recall that \(\theta\) are the parameters of a neural network that models \(p(x^{(i)}\vert z)\)) such that:

$$ \theta \leftarrow \text{argmax}_\theta \frac{1}{N}\sum_{i=1}^N \mathcal{L}_i(p,q_i) $$

This can be done in the following way:

For each \(x^{(i)}\):

  1. Calculate \(\nabla_\theta \mathcal{L}_i(p,q_i) =\nabla_\theta \mathbb{E}_{z \sim q_i(z)}\left[\log p_\theta(x^{(i)} \vert z)\right]\)
  2. Update \(\theta \leftarrow \theta + \alpha\nabla_\theta \mathcal{L}_i(p,q_i)\)
  3. Update \(q_i\) to maximize \(\mathcal{L}_i(p,q_i)\)

We can approximate the expectation in step 1 with a single sample, i.e.:

  1. Sample \(z \sim q_i(z)\)
  2. Approximate \(\nabla_\theta \mathcal{L}_i(p,q_i) \approx \nabla_\theta \log p_\theta(x^{(i)}\vert z)\)

If \(q_i(z)\) is a Gaussian with mean \(\mu_i\) and standard deviation \(\sigma_i\), then we can perform step 3 by computing \(\nabla_{\mu_i}\mathcal{L}_i(p,q_i)\) and \(\nabla_{\sigma_i}\mathcal{L}_i(p,q_i)\) and simply doing gradient ascent on \(\mu_i\) and \(\sigma_i\).
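Step 3 can be sketched for a single data point under the same kind of toy assumptions: \(p(z) = \mathcal{N}(0,1)\) and \(p(x \vert z) = \mathcal{N}(z,1)\) are held fixed, and \(q_i(z) = \mathcal{N}(m, s^2)\). For this model the bound is \(\mathcal{L}_i = \text{const} - \frac{1}{2}\left((x-m)^2+s^2\right) - \frac{1}{2}\left(m^2+s^2\right) + \log s\), so its gradients can be written by hand. Optimizing \(\log s\) rather than \(s\) is a choice made here to keep \(s\) positive; the step size and iteration count are arbitrary.

```python
import math

def fit_q(x, steps=500, lr=0.05):
    # Toy model: z ~ N(0, 1), x | z ~ N(z, 1); variational family q_i(z) = N(m, s^2).
    m, log_s = 0.0, 0.0                    # optimise log s so that s stays positive
    for _ in range(steps):
        s = math.exp(log_s)
        grad_m = x - 2.0 * m               # d L_i / d m
        grad_log_s = 1.0 - 2.0 * s ** 2    # d L_i / d log s (chain rule through s)
        m += lr * grad_m                   # gradient ascent on the mean
        log_s += lr * grad_log_s           # gradient ascent on the log std
    return m, math.exp(log_s)

m_star, s_star = fit_q(1.5)
```

For \(x = 1.5\) the iterates approach \(m = 0.75\) and \(s^2 = 0.5\), which is the true posterior \(\mathcal{N}(x/2, 1/2)\) of this toy model. In a realistic model the gradients would be estimated from samples instead of being available in closed form.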

4. Amortized Variational Approximation

One major issue with the method presented in the previous section is that we have a separate distribution for each point in our dataset. This means that we have a total of \(\vert \theta \vert+(\vert\mu_i\vert+\vert\sigma_i\vert)\cdot N\) parameters. One way to avoid having so many parameters is to train a single neural network with parameters \(\phi\) to approximate \(p(z \vert x^{(i)})\).

Given an input \(x^{(i)}\), the network, denoted by \(q_\phi\), can, for example, output the mean and variance of a Gaussian distribution. This distribution approximates \(p(z \vert x^{(i)})\). We can then simply sample some \(z\) from this distribution and use it to calculate \(\nabla_\theta\mathcal{L}_i(p,q_\phi)\) and update \(\theta\).

To update \(\phi\) we need to compute \(\nabla_\phi \mathcal{L}_i\). We can write \(\mathcal{L}_i\) as:

$$ \mathcal{L}_i(p,q_\phi) = \mathbb{E}_{z \sim q_\phi(z\vert x^{(i)})}\left[\log p(x^{(i)} \vert z) + \log p(z)\right] + \mathcal{H}(q_\phi(z\vert x^{(i)})) $$

The gradient of the second term with respect to \(\phi\) can be derived easily (see footnote 3). To compute the gradient of the first term (which we denote by \(J(\phi)\)) with respect to \(\phi\), we would first need the derivative with respect to \(z\) and then the derivative of \(z\) with respect to \(\phi\). However, \(z\) was sampled from a probability distribution (and is thus a stochastic quantity), so we cannot directly compute its derivative with respect to \(\phi\).

One way around this is to note that \(J(\phi)\) has the same form that we had with policy gradients (in which we had the expectation of the sum of rewards under some policy distribution). Just as we did there, we can approximate the gradient of this expectation using samples \(z_1,\ldots,z_M \sim q_\phi(z \vert x^{(i)})\):

$$ \nabla_\phi J(\phi) \approx \frac{1}{M}\sum_{j=1}^M \nabla_\phi \log q_\phi(z_j\vert x^{(i)})\left(\log p(x^{(i)} \vert z_j) + \log p(z_j)\right) $$

However, this has the same problem that policy gradients had: high variance. A better way to approximate this gradient is the reparameterization trick, which rewrites \(z\) as follows (we assume that \(q_\phi\) outputs the mean \(\mu_\phi\) and standard deviation \(\sigma_\phi\) of a Gaussian distribution):

$$ z = \mu_\phi + \epsilon\sigma_\phi $$

where \(\epsilon \sim \mathcal{N}(0,1)\). So while \(z\) is still a random variable, we can compute its derivative with respect to \(\phi\) (as the stochastic quantity i.e. \(\epsilon\) in its equation is independent of \(\phi\)). Therefore, to approximate \(\nabla_\phi J(\phi)\), we can:

  1. Sample \(\epsilon_1,\ldots,\epsilon_M\) from \(\mathcal{N}(0,1)\) (often a single sample works fine)
  2. Approximate \(\nabla_\phi J(\phi) \approx \frac{1}{M}\sum_{j=1}^M\nabla_\phi \log p(x^{(i)} \vert \mu_\phi + \epsilon_j \sigma_\phi)\) (see footnote 4)

The main drawback of the reparameterization trick is that it can only handle continuous latent variables, whereas policy gradients can handle both continuous and discrete latent variables. However, compared with policy gradients, the reparameterization trick has lower variance and is easy to implement.
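The variance difference can be illustrated on a deliberately simple problem that is not taken from the notes: estimating \(\nabla_\mu \mathbb{E}_{z \sim \mathcal{N}(\mu,1)}[z^2]\), whose exact value is \(2\mu\). The sketch compares per-sample gradient estimates from the score-function (policy-gradient style) estimator and from the reparameterized estimator with \(z = \mu + \epsilon\).

```python
import random

def score_function_sample(mu, rng):
    z = rng.gauss(mu, 1.0)
    # grad_mu log N(z; mu, 1) = (z - mu), multiplied by the integrand f(z) = z^2.
    return (z - mu) * z ** 2

def reparam_sample(mu, rng):
    eps = rng.gauss(0.0, 1.0)
    z = mu + eps          # z = mu + eps * sigma with sigma = 1
    return 2.0 * z        # d/d mu of (mu + eps)^2

def mean_and_variance(values):
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, var

rng = random.Random(0)
mu, n = 1.0, 20000
sf_mean, sf_var = mean_and_variance([score_function_sample(mu, rng) for _ in range(n)])
rp_mean, rp_var = mean_and_variance([reparam_sample(mu, rng) for _ in range(n)])
```

Both estimators average to roughly \(2\mu = 2\), but the per-sample variance of the score-function estimator is several times larger than that of the reparameterized one in this example. The size of the gap depends on the problem, so this is an illustration and not a general guarantee.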

To summarize, we feed \(x^{(i)}\) to the neural network \(q_\phi\) which outputs \(\mu_\phi\) and \(\sigma_\phi\). We sample \(\epsilon\) from a Gaussian distribution with mean \(0\) and variance \(1\) and compute \(z\) using the reparameterization trick. Finally, we feed \(z\) to the neural network \(p_\theta\) which models the distribution \(p(x^{(i)}\vert z)\).

5. The Variational Autoencoder

The algorithm described at the end of the previous section is known as the variational autoencoder. Concretely, the variational autoencoder has the following objective function:

$$ \max_{\phi,\theta} \log p(x^{(i)} \vert \mu_\phi + \epsilon \sigma_\phi)-D_{KL}(q_\phi(z\vert x^{(i)})\vert\vert p(z)) $$
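A single-sample evaluation of this objective might look like the sketch below. The one-dimensional `encoder` and `decoder_log_lik` are hypothetical affine stand-ins for the neural networks \(q_\phi\) and \(p_\theta\), the prior is assumed to be \(\mathcal{N}(0,1)\), and the parameter values are arbitrary. The KL term uses the standard closed form \(\frac{1}{2}(\mu^2 + \sigma^2 - 1 - \log\sigma^2)\) for a Gaussian against a standard Gaussian.

```python
import math
import random

def kl_to_standard_normal(mu, sigma):
    # D_KL(N(mu, sigma^2) || N(0, 1)) in closed form.
    return 0.5 * (mu ** 2 + sigma ** 2 - 1.0 - math.log(sigma ** 2))

def encoder(x, phi):
    # Hypothetical one-dimensional "network": affine mean, constant log-std.
    a, c, log_sigma = phi
    return a * x + c, math.exp(log_sigma)

def decoder_log_lik(x, z, theta):
    # Hypothetical decoder: p(x | z) = N(w * z + b, 1).
    w, b = theta
    return -0.5 * math.log(2.0 * math.pi) - 0.5 * (x - (w * z + b)) ** 2

def vae_objective(x, phi, theta, rng):
    mu, sigma = encoder(x, phi)
    eps = rng.gauss(0.0, 1.0)
    z = mu + eps * sigma  # reparameterised sample from q_phi(z | x)
    return decoder_log_lik(x, z, theta) - kl_to_standard_normal(mu, sigma)

rng = random.Random(0)
value = vae_objective(1.5, phi=(0.5, 0.0, -0.3), theta=(1.0, 0.0), rng=rng)
```

In practice the encoder and decoder would be neural networks, and an automatic differentiation library would compute the gradients of this quantity with respect to \(\phi\) and \(\theta\); the sketch only shows how the two terms of the objective are assembled.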

6. Conditional Models

In case we are also given the labels \(y\) we can simply rewrite \(\mathcal{L}_i\) as (since we want to model \(p(y^{(i)}\vert x^{(i)})\)):

$$ \mathcal{L}_i(p,q_\phi) = \mathbb{E}_{z \sim q_\phi(z\vert x^{(i)},y^{(i)})} \left[ \log p(y^{(i)} \vert x^{(i)},z) + \log p(z\vert x^{(i)})\right] + \mathcal{H}(q_\phi(z\vert x^{(i)},y^{(i)})) $$

  1. This follows from the fact that for some positive \(\alpha_i\) where \(\sum_i \alpha_i = 1\), we always have: \(\max_{x_i} f(x_i) \geq \sum \alpha_i f(x_i)\). 

  2. Note that \(\log\) is a concave function as \(d^2[\log x]/dx^2 = -1/x^2 \leq 0\) so \(\mathbb{E}[\log x] \leq \log \mathbb{E}[x]\).

  3. For example, the entropy of a Gaussian distribution with standard deviation \(\sigma\) can be shown to be equal to \(\frac{1}{2}\ln(2\pi e\sigma^2)\) (it does not depend on the mean). Computing the gradient with respect to \(\sigma\) is thus trivial. 

  4. Here we have ignored \(\log p(z)\). This is because \(\mathbb{E}_{z \sim q_\phi(z\vert x^{(i)})}\left[\log p(z)\right] + \mathcal{H}(q_\phi)\) is simply equal to \(-D_{KL}(q_\phi(z\vert x^{(i)})\vert\vert p(z))\) which often has a convenient analytical solution.