Variational autoencoders
The goal of this post is to sketch an intuitive and clean “Big picture” explanation of the motivation behind variational autoencoders, and to provide a detailed walk-through of the derivation from the original paper Auto-Encoding Variational Bayes.
Big picture
Suppose we have some observed or measured data \(x\), which we assume was generated by some unknown stochastic process from some “hidden data” \(z\), called latent variables. Moreover, we assume that we know the parametric family of the latent variable distribution \(p_\theta(z)\), as well as that of the conditional distribution \(p_\theta(x \vert z)\), for our dataset. To model the conditional distribution \(p_\theta(x \vert z)\) as a function of the latent variables \(z\), we use a neural network that takes \(z\) as input and outputs the parameters of the distribution \(p_\theta(x \vert z)\). Note that the weights of this network and the parameters determining the actual distribution \(p_\theta(z)\) are both collected in a common parameter vector \(\theta\).
Knowing the prior latent variable distribution \(p_\theta(z)\) and the conditional distribution \(p_\theta(x \vert z)\), it should in theory be possible to determine the posterior distribution \(p_\theta(z \vert x)\) as well. However, since the parameters of \(p_\theta(x \vert z)\) are the output of a possibly complicated neural network, a closed-form formula for \(p_\theta(z \vert x)\) is in general intractable. Therefore, we approximate this posterior distribution with a second neural network. This network takes \(x\) as input and outputs the parameters of an approximate posterior distribution \(q_\phi(z \vert x)\). Here, \(\phi\) represents the weights of this second network. By replacing \(p\) with \(q\), we acknowledge that this conditional distribution is only an approximation, not the one implied by \(p_\theta(z)\) and \(p_\theta(x \vert z)\). Also note that we need to choose a parametric family of distributions for \(q_\phi(z \vert x)\).
Now we have models for the relationship between \(x\) and \(z\) in the form of conditional distributions, but how do we find the parameters \(\theta\), \(\phi\) of these models? The answer is rather standard: by maximizing the likelihood of the observed data \(x\). The maximum likelihood objective \(\log(p_\theta(x))\) is rewritten into an algebraically equivalent form. This form also involves \(p_\theta(x \vert z)\), \(q_\phi(z \vert x)\) and the intractable true posterior \(p_\theta(z \vert x)\). Because of this last term, part of the resulting formula cannot be evaluated, let alone optimized. However, the remaining part turns out to be a lower bound on the log-likelihood, so we maximize this lower bound instead. The loss function is simply this lower bound with the opposite sign. The model parameters (network weights) are then found by gradient-based optimization using backpropagation. Part of the lower bound formula is an expected value, which is approximated by Monte Carlo sampling during training.
And what does this process have to do with autoencoders? After training the models for \(p_\theta(x \vert z)\) and \(q_\phi(z \vert x)\), we can encode any \(x\) into \(z\) in two steps. First, we obtain the parameters of \(q_\phi(z \vert x)\) by running the corresponding neural network with input \(x\). Then, we sample \(z\) from \(q_\phi(z \vert x)\). Likewise, we can decode \(z\) back into \(x\) using the same approach with \(p_\theta(x \vert z)\). Because we can encode an observed \(x\) into a latent \(z\) and decode it back into \(x\), the whole model is an autoencoder. Moreover, thanks to the sampling involved, we can draw several possible encodings \(z\) of a given \(x\), or decode a given \(z\) into several possible values of \(x\).
Detailed derivation
Model setting
Following the explanation in the previous section, let us dive into the details of the paper. Let us assume that:
- we know the parametric family of the distribution \(p_\theta(z)\) of the latent variable \(z\), for example a standard multivariate normal distribution (with no dependence on \(\theta\)):
\[p_\theta(z) = \mathcal{N}(0, I)\]
- we know the parametric family of the conditional distribution \(p_\theta(x \vert z)\) of the observed variable \(x\) given the latent variable \(z\), e.g.:
  - First example, real-valued data: multivariate normal distribution
\[p_\theta(x \vert z) = \mathcal{N}(\mu^p_\theta(z), \Sigma^p_\theta(z))\]
  - Second example, binary data: multivariate Bernoulli distribution
\[p_\theta(x \vert z) = \text{Bernoulli}(P_\theta(z))\]
where in the first example \(\mu^p_\theta(z)\) and \(\Sigma^p_\theta(z)\) are two neural networks or two heads of a shared neural network. In the second example, \(P_\theta(z)\) is a neural network with a sigmoid activation on its output, so that every component lies in \((0, 1)\). The weights of these networks are the parameters \(\theta\), and the input of these networks is the vector of latent variables \(z\).
- With nontrivial neural networks producing the parameters of the conditional distribution \(p_\theta(x \vert z)\), the posterior distribution \(p_\theta(z \vert x)\) becomes intractable, meaning that it cannot be expressed by a closed-form formula. Therefore, we approximate this posterior distribution with a parametrized distribution \(q_\phi(z \vert x)\) from a chosen family, for example a multivariate normal distribution with a diagonal covariance matrix:
\[q_\phi(z \vert x) = \mathcal{N}(\mu^q_\phi(x), \operatorname{diag}((\sigma^q_\phi)^2(x)))\]where \(\mu^q_\phi(x)\) and \((\sigma^q_\phi)^2(x)\) are two neural networks or two heads of a shared neural network.
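Below is a minimal sketch of these model components for the binary (second) example. It assumes NumPy, a latent dimension \(J = 2\), a data dimension of 4, and single linear layers with random weights standing in for the neural networks. The names `decoder_probs` and `encoder_params`, and all weight matrices, are hypothetical. The encoder's variance head outputs \(\log((\sigma^q_\phi)^2)\) and exponentiates it, which is a common way to keep the variance positive.

```python
import numpy as np

rng = np.random.default_rng(0)
J, M = 2, 4  # latent dimension J and data dimension M (illustrative choices)

# Hypothetical "networks": single linear layers with random weights.
# Decoder weights play the role of theta, encoder weights the role of phi.
W_dec, b_dec = rng.normal(size=(M, J)), np.zeros(M)
W_mu, b_mu = rng.normal(size=(J, M)), np.zeros(J)
W_lv, b_lv = rng.normal(size=(J, M)), np.zeros(J)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def decoder_probs(z):
    """P_theta(z): Bernoulli parameters of p_theta(x | z), squashed into (0, 1)."""
    return sigmoid(W_dec @ z + b_dec)

def encoder_params(x):
    """Parameters of q_phi(z | x): mean mu and diagonal variance sigma^2.

    The second head outputs log(sigma^2); exponentiating keeps the variance positive.
    """
    mu = W_mu @ x + b_mu
    var = np.exp(W_lv @ x + b_lv)
    return mu, var

# Generative process: z ~ p(z) = N(0, I), then x ~ Bernoulli(P_theta(z)).
z = rng.standard_normal(J)
x = (rng.random(M) < decoder_probs(z)).astype(float)

# Parameters of the approximate posterior q_phi(z | x) for the generated x.
mu_q, var_q = encoder_params(x)
print(x, mu_q, var_q)
```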
Optimization objective
Now the classical objective follows: we want to maximize the likelihood of the observed i.i.d. (independent and identically distributed) data \(x^{(1)}, x^{(2)}, \dots, x^{(n)}\):
\[\begin{align} \theta^*, \phi^* &= \arg\max_{\theta, \phi} p_\theta(x^{(1)}, x^{(2)}, \dots, x^{(n)}) \\ &= \arg\max_{\theta, \phi} \prod_{i=1}^n p_\theta(x^{(i)})\\ &= \arg\max_{\theta, \phi} \sum_{i=1}^n \log(p_\theta(x^{(i)}))\\ &= \arg\max_{\theta, \phi} \sum_{i=1}^n \left( D_{KL}(q_\phi(z \vert x^{(i)}) || p_\theta(z \vert x^{(i)})) + \mathcal{L}(\theta, \phi, x^{(i)}) \right) \tag{1}\label{eq:dkl_vlb} \end{align}\]where \(D_{KL}(q_\phi(z \vert x^{(i)}) \vert\vert p_\theta(z \vert x^{(i)}))\) is the Kullback-Leibler divergence between the approximate posterior and the true (but intractable) posterior distribution of \(z\). The term \(\mathcal{L}(\theta, \phi, x^{(i)})\) is the so-called variational lower bound (the reason for the name will become clear soon), defined as
\[\begin{align} \mathcal{L}(\theta, \phi, x^{(i)}) &= \mathbb{E}_{q_\phi(z \vert x^{(i)})} \left[ -\log(q_\phi(z \vert x^{(i)})) + \log(p_\theta(x^{(i)}, z)) \right] \tag{2}\label{eq:vlb_1}\\ &= -D_{KL}(q_\phi(z \vert x^{(i)}) || p_\theta(z)) + \mathbb{E}_{q_\phi(z \vert x^{(i)})} \left[ \log(p_\theta(x^{(i)} | z)) \right] \tag{3}\label{eq:vlb_2} \end{align}\]The equality in \eqref{eq:dkl_vlb} can be shown by simple algebraic manipulations.
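For completeness, here is one way to carry out these manipulations. Since \(\log(p_\theta(x^{(i)}))\) does not depend on \(z\), it equals its own expected value under \(q_\phi(z \vert x^{(i)})\). Bayes' rule gives \(p_\theta(x^{(i)}) = p_\theta(x^{(i)}, z) / p_\theta(z \vert x^{(i)})\). Adding and subtracting \(\log(q_\phi(z \vert x^{(i)}))\) inside the expectation then splits it into the two terms of \eqref{eq:dkl_vlb}:
\[\begin{align} \log(p_\theta(x^{(i)})) &= \mathbb{E}_{q_\phi(z \vert x^{(i)})} \left[ \log(p_\theta(x^{(i)}, z)) - \log(p_\theta(z \vert x^{(i)})) \right] \\ &= \mathbb{E}_{q_\phi(z \vert x^{(i)})} \left[ \log(q_\phi(z \vert x^{(i)})) - \log(p_\theta(z \vert x^{(i)})) \right] + \mathbb{E}_{q_\phi(z \vert x^{(i)})} \left[ -\log(q_\phi(z \vert x^{(i)})) + \log(p_\theta(x^{(i)}, z)) \right] \\ &= D_{KL}(q_\phi(z \vert x^{(i)}) || p_\theta(z \vert x^{(i)})) + \mathcal{L}(\theta, \phi, x^{(i)}) \end{align}\]
Similarly, splitting \(\log(p_\theta(x^{(i)}, z)) = \log(p_\theta(x^{(i)} \vert z)) + \log(p_\theta(z))\) inside \eqref{eq:vlb_1} gives
\[\mathcal{L}(\theta, \phi, x^{(i)}) = \mathbb{E}_{q_\phi(z \vert x^{(i)})} \left[ \log(p_\theta(z)) - \log(q_\phi(z \vert x^{(i)})) \right] + \mathbb{E}_{q_\phi(z \vert x^{(i)})} \left[ \log(p_\theta(x^{(i)} \vert z)) \right],\]
in which the first expectation is exactly \(-D_{KL}(q_\phi(z \vert x^{(i)}) \vert\vert p_\theta(z))\). This yields \eqref{eq:vlb_2}.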
Optimizing the lower bound
As the true posterior \(p_\theta(z \vert x^{(i)})\) is intractable, so is the whole term \(D_{KL}(q_\phi(z \vert x^{(i)}) \vert\vert p_\theta(z \vert x^{(i)}))\). Therefore, we cannot directly optimize (search for the argmax of) \eqref{eq:dkl_vlb}. However, as the Kullback-Leibler divergence is always non-negative, \(\mathcal{L}(\theta, \phi, x^{(i)})\) is a lower bound on each summand \(\log(p_\theta(x^{(i)}))\) in \eqref{eq:dkl_vlb}. Moreover, since \(\log(p_\theta(x^{(i)}))\) does not depend on \(\phi\), increasing \(\mathcal{L}\) with respect to \(\phi\) can only decrease the KL divergence, i.e. it pushes \(q_\phi(z \vert x^{(i)})\) towards the true posterior. Therefore, we can at least search for parameters \(\theta, \phi\) that maximize the sum of these lower bounds, which is in turn a lower bound on the whole expression \eqref{eq:dkl_vlb}:
\[\tilde\theta, \tilde\phi = \arg\max_{\theta, \phi} \sum_{i=1}^n \mathcal{L}(\theta, \phi, x^{(i)}) = \arg\min_{\theta, \phi} \sum_{i=1}^n - \mathcal{L}(\theta, \phi, x^{(i)}) \tag{4}\label{eq:objective}\]If we are able to express the term \(-D_{KL}(q_\phi(z \vert x^{(i)}) \vert\vert p_\theta(z))\) analytically, it is preferable to use \eqref{eq:vlb_2} as the formula for \(\mathcal{L}(\theta, \phi, x^{(i)})\). In that case only the second term has to be estimated by sampling, which typically reduces the variance of the estimate. Otherwise, \eqref{eq:vlb_1} has to be used. The sum inside the argmin is then the global loss function. For standard minibatch gradient descent methods, the loss of a single data sample (summed or averaged over the samples in a minibatch) is simply
\[LOSS(\theta, \phi, x^{(i)}) = -\mathcal{L}(\theta, \phi, x^{(i)}) \tag{5}\label{eq:loss}\]
Reparametrization trick
Both \eqref{eq:vlb_1} and \eqref{eq:vlb_2} contain an expected value \(\mathbb{E}_{q_\phi(z \vert x^{(i)})}[\cdot]\). We could approximate this expected value by Monte Carlo sampling from the distribution \(q_\phi(z \vert x^{(i)})\). However, to solve \eqref{eq:objective} with gradient descent, we would need to “differentiate through the sampling” with respect to the parameters \(\phi\), as these parameters determine the sampling distribution itself. Backpropagation cannot pass gradients through a random sampling step. Therefore, the so-called reparametrization trick is used to express \(z \sim q_\phi(z \vert x^{(i)})\) in a different way:
\[z = g_\phi(x^{(i)}, \epsilon), \quad \epsilon \sim p(\epsilon) \quad \text{ so that } \quad z \sim q_\phi(z \vert x^{(i)})\]Here \(\epsilon\) is a random variable with a known distribution \(p(\epsilon)\) that is independent of \(\phi\) and \(x^{(i)}\). Such a transformation exists for many distributions; see the original paper by Kingma and Welling for details. By substituting \(g_\phi(x^{(i)}, \epsilon)\) for \(z\), we can rewrite the original expected value
\[\mathbb{E}_{q_\phi(z \vert x^{(i)})}[f(x^{(i)}, z)] = \mathbb{E}_{p(\epsilon)}[f(x^{(i)}, g_\phi(x^{(i)}, \epsilon))]\]and thereby make the sampling in its Monte Carlo approximation independent of the parameters \(\phi\). For the diagonal normal distribution \(q_\phi(z \vert x^{(i)})\) from the Model setting section, we can use
\[z = g_\phi(x^{(i)}, \epsilon) = \mu^q_\phi(x^{(i)}) + \sigma^q_\phi(x^{(i)}) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]where \(\odot\) denotes elementwise multiplication.
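To see why this helps, here is a small NumPy check. It assumes a one-dimensional \(z\) and the toy function \(f(z) = z^2\), chosen only because \(\mathbb{E}[z^2] = \mu^2 + \sigma^2\) has the known derivatives \(2\mu\) and \(2\sigma\). Because \(z = \mu + \sigma \epsilon\) is a differentiable function of \(\mu\) and \(\sigma\), the chain rule gives the per-sample gradients \(f'(z)\) and \(f'(z)\,\epsilon\). Their Monte Carlo averages estimate the gradient of the expectation. The function name `reparam_grad_estimate` is hypothetical.

```python
import numpy as np

def reparam_grad_estimate(mu, sigma, n_samples, rng):
    """Pathwise estimate of the gradient of E[z**2] w.r.t. mu and sigma, z ~ N(mu, sigma**2)."""
    eps = rng.standard_normal(n_samples)  # eps ~ N(0, 1), independent of mu and sigma
    z = mu + sigma * eps                  # z = g(eps) is distributed as N(mu, sigma**2)
    df_dz = 2.0 * z                       # f(z) = z**2, so f'(z) = 2z
    grad_mu = np.mean(df_dz)              # chain rule with dz/dmu = 1
    grad_sigma = np.mean(df_dz * eps)     # chain rule with dz/dsigma = eps
    return grad_mu, grad_sigma

rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.7
grad_mu, grad_sigma = reparam_grad_estimate(mu, sigma, 200_000, rng)
# Exact gradients of E[z**2] = mu**2 + sigma**2 are 2*mu = 3.0 and 2*sigma = 1.4.
print(grad_mu, grad_sigma)
```

In a real VAE, an automatic differentiation framework computes these chain-rule terms. The point here is only that the randomness enters through \(\epsilon\), which does not depend on the parameters being differentiated.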
Monte Carlo sampling
Now we can approximate the expected value in \eqref{eq:vlb_1} or in \eqref{eq:vlb_2} by a Monte Carlo estimate in which the sampling does not depend on the network parameters:
\[\mathbb{E}_{q_\phi(z \vert x^{(i)})}[f(x^{(i)}, z)] = \mathbb{E}_{p(\epsilon)}[f(x^{(i)}, g_\phi(x^{(i)}, \epsilon))] \approx \frac{1}{L}\sum_{l=1}^L f(x^{(i)}, g_\phi(x^{(i)}, \epsilon^{(i, l)})), \quad \epsilon^{(i, l)} \sim p(\epsilon)\]where \(L\) is the number of samples. For the distributions \(q_\phi(z \vert x^{(i)})\) and \(p_\theta(z)\) from our examples, we can use expression \eqref{eq:vlb_2}, as its \(D_{KL}\) term can be expressed analytically (see Appendix B of the original paper for details):
\[D_{KL}(q_\phi(z \vert x^{(i)}) || p_\theta(z)) = -\frac{1}{2} \sum_{j=1}^J \left(1 + \log((\sigma^q_\phi)^2_j(x^{(i)})) - (\mu^q_\phi)^2_j(x^{(i)}) - (\sigma^q_\phi)^2_j(x^{(i)})\right)\]where \(J\) is the latent space dimension. The second summand in \eqref{eq:vlb_2} is the term \(\mathbb{E}_{q_\phi(z \vert x^{(i)})} \left[ \log(p_\theta(x^{(i)} \vert z)) \right]\) that will be estimated using Monte Carlo approximation:
\[\mathbb{E}_{q_\phi(z \vert x^{(i)})} \left[ \log(p_\theta(x^{(i)} | z)) \right] \approx \frac{1}{L}\sum_{l=1}^L \left[ \log\left(p_\theta(x^{(i)} | z=g_\phi(x^{(i)}, \epsilon^{(i, l)}))\right) \right], \quad \epsilon^{(i, l)} \sim p(\epsilon).\]Note that the parameters \(\phi\), which determine the sampling distribution, now appear explicitly inside the Monte Carlo estimate through \(g_\phi\). Gradients with respect to \(\phi\) can therefore be computed by ordinary backpropagation. Without the reparametrization trick, i.e. when sampling \(z\) directly from \(q_\phi(z \vert x^{(i)})\), \(\phi\) would influence the estimate only through the sampling step, and this dependence would be lost during gradient computation.
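The following is a sanity check of the analytic KL term, the Monte Carlo estimate and the lower-bound property. It uses a hypothetical one-dimensional toy model (not from the paper) in which everything is tractable: \(p(z) = \mathcal{N}(0, 1)\) and \(p(x \vert z) = \mathcal{N}(z, 1)\). In this model \(p(x) = \mathcal{N}(0, 2)\) and the true posterior is \(p(z \vert x) = \mathcal{N}(x/2, 1/2)\). The NumPy sketch below evaluates \eqref{eq:vlb_2} for several choices of \(q(z \vert x) = \mathcal{N}(m, s^2)\) and compares the result with the exact \(\log(p(x))\). The function names and the observed value \(x = 1.3\) are illustrative assumptions.

```python
import numpy as np

def log_normal_pdf(x, mean, var):
    """Log density of the univariate normal N(mean, var) evaluated at x."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

def kl_to_standard_normal(m, s2):
    """Analytic D_KL(N(m, s2) || N(0, 1)): the J = 1 case of the formula above."""
    return -0.5 * (1 + np.log(s2) - m ** 2 - s2)

def elbo_estimate(x, m, s2, n_samples, rng):
    """Lower bound (3): minus the analytic KL plus a Monte Carlo estimate of E_q[log p(x | z)]."""
    eps = rng.standard_normal(n_samples)
    z = m + np.sqrt(s2) * eps                              # reparametrized samples from q(z | x)
    expected_loglik = np.mean(log_normal_pdf(x, z, 1.0))  # toy likelihood p(x | z) = N(z, 1)
    return -kl_to_standard_normal(m, s2) + expected_loglik

rng = np.random.default_rng(0)
x = 1.3
log_px = log_normal_pdf(x, 0.0, 2.0)  # exact log-likelihood, since p(x) = N(0, 2)

elbos = []
for m, s2 in [(0.0, 1.0), (1.0, 0.2), (x / 2, 0.5)]:  # the last q is the true posterior
    elbo = elbo_estimate(x, m, s2, 100_000, rng)
    elbos.append(elbo)
    print(f"m={m:.2f}, s2={s2:.2f}: ELBO={elbo:.4f}, log p(x)={log_px:.4f}")
```

The gap between \(\log(p(x))\) and each printed lower bound is the KL divergence from that \(q\) to the true posterior. Up to Monte Carlo noise, the bound is therefore tight only for the last choice of \(q\).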
Final loss formula for chosen examples
Let us now put the pieces together and express the loss \eqref{eq:loss} as a function of an arbitrary data sample \(x^{(i)}\) and the neural network parameters \(\theta, \phi\) for the examples from the Model setting section:
\[\begin{align} &LOSS(\theta, \phi, x^{(i)}) = -\mathcal{L}(\theta, \phi, x^{(i)})\\ =& D_{KL}(q_\phi(z \vert x^{(i)}) || p_\theta(z)) - \mathbb{E}_{q_\phi(z \vert x^{(i)})} \left[ \log(p_\theta(x^{(i)} | z)) \right] \\ \simeq& - \frac{1}{2} \sum_{j=1}^J \left(1 + \log((\sigma^q_\phi)^2_j(x^{(i)})) - (\mu^q_\phi)^2_j(x^{(i)}) - (\sigma^q_\phi)^2_j(x^{(i)})\right) \\ &- \frac{1}{L}\sum_{l=1}^L \left[ \log\left(p_\theta(x^{(i)} | z=g_\phi(x^{(i)}, \epsilon^{(i, l)}))\right) \right] \\ =& - \frac{1}{2} \sum_{j=1}^J \left(1 + \log((\sigma^q_\phi)^2_j(x^{(i)})) - (\mu^q_\phi)^2_j(x^{(i)}) - (\sigma^q_\phi)^2_j(x^{(i)})\right) \\ &- \frac{1}{L}\sum_{l=1}^L \left[ \log\left(p_\theta(x^{(i)} | z=\mu^q_\phi(x^{(i)}) + \sigma^q_\phi(x^{(i)}) \odot \epsilon^{(i, l)})\right) \right], \quad \epsilon^{(i, l)} \sim \mathcal{N}(0, I) \end{align}\]Here, in the binary data case (second example), let \(M\) be the dimension of a data sample and \(P_\theta(z)_m\) the \(m\)-th component of \(P_\theta(z)\). Then
\[p_\theta(x^{(i)}|z) = \prod_{m=1}^M \left( P_\theta(z)_m \, x^{(i)}_m + (1-P_\theta(z)_m) (1-x^{(i)}_m) \right)\]and in the case of real-valued, normally distributed data \(x^{(i)}\) (first example)
\[p_\theta(x^{(i)}|z) = f_{\mathcal{N}(\mu^p_\theta(z), \Sigma^p_\theta(z))}(x^{(i)})\]where \(f_{\mathcal{N}(\mu^p_\theta(z), \Sigma^p_\theta(z))}\) is the probability density function of the multivariate normal distribution \(\mathcal{N}(\mu^p_\theta(z), \Sigma^p_\theta(z))\).
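Putting the formula above into code, here is a minimal NumPy sketch of the per-sample loss for the binary-data (Bernoulli) example. It assumes single linear layers with random weights in place of the encoder and decoder networks, and an encoder head that outputs \(\log((\sigma^q_\phi)^2)\). All names (`vae_loss`, `bernoulli_log_likelihood`, `W_mu`, ...) are hypothetical. The sketch only evaluates the loss; an automatic differentiation framework would compute its gradients in practice.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def bernoulli_log_likelihood(x, p):
    """log p_theta(x | z) = sum_m [x_m log P_m + (1 - x_m) log(1 - P_m)] for binary x."""
    return np.sum(x * np.log(p) + (1 - x) * np.log(1 - p))

def vae_loss(x, params, n_samples, rng):
    """Per-sample loss -L(theta, phi, x): analytic KL term minus Monte Carlo log-likelihood."""
    W_mu, W_lv, W_dec, b_dec = params
    mu = W_mu @ x        # mu^q_phi(x), encoder mean head (hypothetical linear layer)
    log_var = W_lv @ x   # log((sigma^q_phi)^2(x)), encoder log-variance head
    var = np.exp(log_var)
    # Analytic D_KL(q_phi(z | x) || N(0, I)), summed over the J latent dimensions.
    kl = -0.5 * np.sum(1 + log_var - mu ** 2 - var)
    # Monte Carlo estimate of E_q[log p_theta(x | z)] with reparametrized samples.
    loglik = 0.0
    for _ in range(n_samples):
        eps = rng.standard_normal(mu.shape)  # eps ~ N(0, I)
        z = mu + np.sqrt(var) * eps           # z = mu + sigma * eps
        p = sigmoid(W_dec @ z + b_dec)        # P_theta(z), decoder with sigmoid output
        loglik += bernoulli_log_likelihood(x, p)
    return kl - loglik / n_samples

rng = np.random.default_rng(0)
J, M = 2, 6  # latent and data dimensions (illustrative choices)
params = (
    0.1 * rng.normal(size=(J, M)),  # W_mu
    0.1 * rng.normal(size=(J, M)),  # W_lv
    rng.normal(size=(M, J)),        # W_dec
    np.zeros(M),                    # b_dec
)
x = np.array([1.0, 0.0, 1.0, 1.0, 0.0, 0.0])
print(vae_loss(x, params, n_samples=5, rng=rng))
```

For the real-valued (first) example, `bernoulli_log_likelihood` would be replaced by the log-density of \(\mathcal{N}(\mu^p_\theta(z), \Sigma^p_\theta(z))\). Computing the Bernoulli term from the decoder's pre-sigmoid outputs with a numerically stable log-sigmoid avoids taking the logarithm of probabilities that round to exactly 0 or 1. The original paper reports that a single sample (\(L = 1\)) per data point suffices when the minibatch is large enough.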