EM Algorithm Recap


\[\newcommand{\argmin}{\mathop{\mathrm{argmin}}} \newcommand{\argmax}{\mathop{\mathrm{argmax}}} \renewcommand{\vec}[1]{\boldsymbol{#1}}\]

Introduction

This post explains the Expectation-Maximization (EM) algorithm from scratch in a fairly concise fashion. The material is based on my own notes, which of course draw on a variety of great resources online that are listed in the references section.

EM is one of the most elegant and widely used machine learning algorithms but is sometimes not thoroughly introduced in introductory machine learning courses. What is so elegant about EM is that, as we shall see, it originates from nothing but the most fundamental laws of probability.

Many variants of EM have been developed, and an important class of statistical machine learning methods called variational inference also has a strong connection to EM. The core ideas and derivatives of EM find many applications in both classical statistical machine learning and models that involve deep neural networks, making it worthwhile to have an intuitive and thorough understanding of it, which is what this post attempts to provide.

Notation

  • Random variables \(X\), probability distribution \(P(X)\)
  • Probability density function (PDF) \(p(\cdot)\), evaluated at value \(x\): \(p(X=x)\) with \(p(x)\) as a shorthand
  • PDF with parameter \(\theta\) is denoted \(p_\theta(x)\) or equivalently \(p(x\vert \theta)\)
  • Expectation of \(f(x)\) according to distribution \(P\): \(\mathbb{E}_{x\sim P}\left[f(x)\right]\)
  • A set is denoted \(\{x_i\}\) or by a calligraphic letter \(\mathcal X\)

Maximum likelihood

Suppose we have data coming from a distribution \(P_D(X)\), and we want to come up with a model for \(x\), parameterized by \(\theta\) and written \(p(x;\theta)\) or equivalently \(p_{\theta}(x)\), that best approximates the real data distribution. Further assume all data samples are independent and identically distributed (iid) according to \(P_D(X)\).

To find \(\theta\) under the maximum likelihood scheme, we maximize the log likelihood \(\ell(\theta)\):

\[\begin{equation} \begin{split} \hat{\theta}_{MLE} &= \argmax_{\theta} \ell(\theta) \\\\ &= \argmax_{\theta} \sum_{i} \log\left( p_{\theta}(x_i) \right) \end{split} \end{equation}\]
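
To make this concrete, here is a minimal sketch (my own toy example, not part of the derivation) of maximum likelihood for a univariate Gaussian model \(p_\theta(x)=\mathcal N(x;\mu,\sigma^2)\), where the maximizer has the familiar closed form; the data below are purely hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=1.5, size=1000)   # hypothetical iid samples from P_D

def log_likelihood(x, mu, sigma2):
    """ell(theta): sum of log N(x_i; mu, sigma2) over the data set."""
    return np.sum(-0.5 * np.log(2 * np.pi * sigma2) - (x - mu) ** 2 / (2 * sigma2))

# Closed-form MLE for a Gaussian: sample mean and (biased) sample variance.
mu_hat = x.mean()
sigma2_hat = ((x - mu_hat) ** 2).mean()

# Sanity check: nearby parameter values should not give a higher log likelihood.
assert log_likelihood(x, mu_hat, sigma2_hat) >= log_likelihood(x, mu_hat + 0.1, sigma2_hat)
assert log_likelihood(x, mu_hat, sigma2_hat) >= log_likelihood(x, mu_hat, 1.1 * sigma2_hat)
print(mu_hat, sigma2_hat)
```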

Motivation for EM

We might encounter situations where, in addition to observed data \(\{x_i\}\), we have missing or hidden data \(\{z_i\}\). It might literally be data that is missing for some reason. Or, more interestingly, it might be due to our modeling choice: we might prefer a model with a set of meaningful but hidden variables \(\{z_i\}\) that help explain the “causes” of \(\{x_i\}\). Good examples in this category are Gaussian (and other kinds of) mixture models and LDA.

Note to self: are there examples where we introduce latent variables purely to make the optimization problem easier?

In either case, we need a model for the joint distribution of \(x\) and \(z\), \(p(x,z;\theta)\), which may arise from assumptions (in the case of missing data) or from models of the marginal density \(p(z; \theta)\) and the conditional density \(p(x\vert z; \theta)\). In such cases, the log likelihood can be expressed as

\[\begin{equation} \begin{split} \ell(\theta) &= \sum_i \log\left( p_{\theta}(x_i) \right)\\\\ &= \sum_i \log\left( \sum_{z} p_{\theta}(x_i, Z=z) \right)\\\\ &= \sum_i \log\left( \sum_{z} p_{\theta}(x_i\vert Z=z)p_{\theta}(Z=z) \right) \end{split} \end{equation}\]

Direct maximization of \(\ell(\theta)\) with respect to \(\theta\) might be challenging due to the summation over \(z\) inside the log. But the problem would be much easier if we knew the values of \(z\): it would simply be the original maximum likelihood problem with all data available.

\[\begin{equation} \begin{split} \ell(\theta) &= \sum_i \log\left(p_{\theta}(x_i\vert Z=z_i)p_{\theta}(Z=z_i) \right) \\\\ &= \sum_i \log\left(p_{\theta}(x_i, z_i) \right) \end{split} \end{equation}\]

The collection \((\{x_i\}, \{z_i\})\) is called the complete data. Naturally, \(\{x_i\}\) is the incomplete data and \(\{z_i\}\) is the latent data (or latent variables).
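
To see the contrast concretely, here is a small sketch (a hypothetical two-component Gaussian mixture of my own choosing) that computes both the incomplete-data log likelihood, with its sum over \(z\) inside the log, and the complete-data log likelihood we could use if the \(z_i\) were observed.

```python
import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp

rng = np.random.default_rng(1)

# Hypothetical two-component mixture: theta = (pi, mu_1, mu_2), unit variances.
pi, mus = np.array([0.3, 0.7]), np.array([-2.0, 2.0])
z = rng.choice(2, size=500, p=pi)          # latent component labels
x = rng.normal(loc=mus[z], scale=1.0)      # observed data

# log p_theta(x_i, z) for every data point i and every value of z.
log_joint = np.log(pi) + norm.logpdf(x[:, None], loc=mus, scale=1.0)  # shape (N, 2)

# Incomplete-data log likelihood: sum_i log sum_z p(x_i, z)  (sum inside the log).
incomplete_ll = logsumexp(log_joint, axis=1).sum()

# Complete-data log likelihood: sum_i log p(x_i, z_i), usable only if z were known.
complete_ll = log_joint[np.arange(len(x)), z].sum()

print(incomplete_ll, complete_ll)
```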

Roughly speaking, the EM algorithm is an iterative method that lets us guess \(z_i\) based on \(x_i\) (and the current estimate of the model parameter, \(\hat\theta\)). With the guessed, “filled-in” \(z_i\) we have complete data, and we optimize the log likelihood \(\ell(\theta)\) with respect to \(\theta\). We thus iteratively improve our guess of the latent variable \(z\) and the parameter \(\theta\), and repeat this process until convergence.

In slightly more detail, instead of guessing a single value of \(z\) we guess the distribution of \(z\) given \(x\), i.e. \(p(z\vert x;\hat\theta)\). We then optimize the expected complete-data log likelihood, i.e. \(\sum_i \mathbb{E}_{z \sim p(z\vert x_i;\hat\theta)}\log p_\theta (x_i, z)\), with respect to \(\theta\); it serves as a proxy (lower bound) for the true objective \(\sum_i \log p_{\theta}(x_i)\). We repeat until convergence.

(Note that guessing a single value for \(z\) is in fact also a valid strategy. It corresponds to a variant of EM and is what we do in the well-known K-means algorithm, where we put a “hard” label on each data point.)

The nice thing about EM is that it comes with a theoretical guarantee of monotonic improvement on the true objective, even though we directly work with a proxy (lower bound) of it. Note however that the rate of convergence will depend on the problem, and convergence is not guaranteed to be towards the global optimum. A skeleton of the EM loop is sketched below.
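
Here is that skeleton, with `e_step`, `m_step` and `log_lik` left as hypothetical model-specific functions (concrete versions appear in the Gaussian mixture example later):

```python
import numpy as np

def run_em(x, theta0, e_step, m_step, log_lik, max_iter=100, tol=1e-6):
    """Generic EM loop.

    e_step(x, theta)  -> q, the guessed distribution p(z | x_i; theta) for each data point
    m_step(x, q)      -> new theta maximizing the expected complete-data log likelihood
    log_lik(x, theta) -> incomplete-data log likelihood, used here only to monitor progress
    """
    theta, prev_ll = theta0, -np.inf
    for _ in range(max_iter):
        q = e_step(x, theta)      # E-step: "fill in" z with a distribution
        theta = m_step(x, q)      # M-step: maximize the expected complete-data log likelihood
        ll = log_lik(x, theta)
        if ll - prev_ll < tol:    # monotone, so this difference is >= 0 (up to numerics)
            break
        prev_ll = ll
    return theta
```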

Formulation

As before, we start with the log likelihood

\[\begin{equation} \begin{split} \ell(\theta) &= \sum_i \log\left( p_{\theta}(x_i) \right) \\\\ &= \sum_i \log\left( \int p_{\theta}(x_i, z) dz \right)\\\\ &= \sum_i \log\left( \int \frac{p_{\theta}(x_i, z)}{q(z)} q(z) dz \right) \\\\ &= \sum_i \log\left( \mathbb{E}_{z \sim Q} \left[ \frac {p_{\theta}(x_i, z)}{q(z)} \right] \right)\\\\ &\ge \sum_i \mathbb{E}_{z \sim Q} \left[\log\left( \frac {p_{\theta}(x_i,z)}{q(z)} \right) \right]\\\\ \label{eq:jensen} \end{split} \end{equation}\]

Here I switched the summation over \(z\) to an integral, assuming \(z\) is continuous, just to hint that this is a possibility. The last step used Jensen’s inequality and the fact that the log function is concave. So far we place no restrictions on the distribution \(Q\), apart from \(q(z)\) being a probability density function that is positive wherever \(p_\theta(x_i,z)\) is.

Using the result above, let’s define the last quantity as \(\mathcal L(q,\theta)\). It is usually called the ELBO (Evidence Lower BOund), as it is a lower bound on \(\ell(\theta)\).

\[\begin{equation} \mathcal L(q,\theta) = \sum_i \mathbb{E}_{z \sim Q} \left[\log\left( \frac {p_{\theta}(x_i,z)}{q(z)} \right) \right] \end{equation}\]

Just to reiterate what we have done so far: our goal is to maximize \(\ell(\theta)\); we exchanged the order of the log and the integral over \(z\) and obtained a lower bound \(\mathcal L\).

We can show that the difference between \(\ell(\theta)\) and \(\mathcal L(q,\theta)\) is

\[\begin{equation} \begin{split} \ell(\theta) - \mathcal L(q,\theta) & = \sum_i \int q(z) \left(\log\left(p_\theta(x_i)\right) - \log\left(\frac{p_\theta(x_i,z)}{q(z)}\right)\right) dz\\\\ &= \sum_i \int q(z) \log\left(\frac{q(z)}{\frac{p_\theta(x_i,z)}{p_\theta(x_i)}}\right) dz \\\\ &= \sum_i \int q(z) \log\left(\frac{q(z)}{p_\theta(z\vert x_i)}\right) dz \\\\ &= \sum_i D_{KL}(q(z) \| p_\theta(z\vert x_i)) \end{split} \end{equation}\]

where we used the definition of the Kullback-Leibler (KL) divergence \(D_{KL}\):

\[D_{KL}(P \| Q)= \int p(x) \log \left( \frac{p(x)}{q(x)} \right) dx = \mathbb{E}_{x\sim P}\left[\log\left( \frac{p(x)}{q(x)} \right)\right]\]

In general, the KL divergence is always nonnegative and is zero if and only if \(q(x) = p(x)\). So in our case, the equality \(\ell(\theta) = \mathcal L(q,\theta)\) holds if and only if \(q(z) = p_\theta(z\vert x_i)\). When this happens, we say the bound is tight. In this case, it makes sense to write \(q(z)\) as \(q(z\vert x_i)\) to make the dependence on \(x_i\) clear.
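
This identity is easy to check numerically. The sketch below uses a made-up discrete example with a single data point \(x_i\) and \(z\in\{0,1,2\}\): it verifies that \(\ell(\theta)-\mathcal L(q,\theta)\) equals \(D_{KL}(q \| p_\theta(z\vert x_i))\) for an arbitrary \(q\), and that the bound is tight when \(q\) equals the posterior.

```python
import numpy as np

# Hypothetical joint p_theta(x_i, z) for one observed x_i and z in {0, 1, 2}.
p_joint = np.array([0.10, 0.25, 0.05])
log_px = np.log(p_joint.sum())          # log p_theta(x_i)
posterior = p_joint / p_joint.sum()     # p_theta(z | x_i)

def elbo(q):
    """Single-datum ELBO: E_{z ~ q}[log(p(x_i, z) / q(z))]."""
    return np.sum(q * (np.log(p_joint) - np.log(q)))

def kl(q, p):
    return np.sum(q * (np.log(q) - np.log(p)))

q = np.array([0.5, 0.3, 0.2])           # an arbitrary q(z)
assert np.isclose(log_px - elbo(q), kl(q, posterior))   # gap equals the KL divergence
assert elbo(q) <= log_px                                # ELBO is a lower bound
assert np.isclose(elbo(posterior), log_px)              # tight when q is the posterior
```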

EM algorithm and monotonicity guarantee

The EM algorithm is remarkably simple and it goes as follows.

  • E-step (of \(t\)-th iteration):
    • Let \(q^t(z) = p(z \vert x_i; \hat\theta^{t-1})\) (one such distribution for each data point \(x_i\)), calculated as shown in Eq. \(\ref{eq:E}\)
    • Due to our particular choice of \(q^t\), at the current estimate \(\hat\theta^{t-1}\) the bound is tight: \(\mathcal L(q^t,\hat\theta^{t-1}) = \ell(\hat\theta^{t-1})\)
  • M-step
    • Maximize \(\mathcal L(q^t,\theta)\) with respect to \(\theta\), see Eq. \(\ref{eq:M}\)
    • This step improves the ELBO by finding a better \(\theta\): \(\mathcal L(q^t,\theta^t) \ge \mathcal L(q^t,\theta^{t-1})\)

The calculation in the E-step is

\[\begin{equation}\label{eq:E} p(z\vert x_i; \hat\theta^{t-1}) = \frac{p(x_i\vert z; \hat\theta^{t-1})p(z; \hat\theta^{t-1})}{\int p(x_i\vert z; \hat\theta^{t-1})p(z; \hat\theta^{t-1}) dz} \end{equation}\]
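
As a sketch of this computation for a discrete latent variable, here is the Bayes-rule calculation of the responsibilities \(p(z\vert x_i;\hat\theta^{t-1})\) for a hypothetical two-component Poisson mixture (the model choice is mine, purely for illustration); the integral in the denominator becomes a sum over the two components.

```python
import numpy as np
from scipy.stats import poisson

def e_step(x, pi, lam):
    """Posterior responsibilities p(z = k | x_i; theta) for a Poisson mixture.

    pi:  mixing weights p(z = k), shape (K,)
    lam: Poisson rates of the components, shape (K,)
    """
    # Numerator of Eq. (E): p(x_i | z = k) p(z = k), in log space for stability.
    log_num = np.log(pi) + poisson.logpmf(x[:, None], lam)        # shape (N, K)
    # Denominator: sum over z of the numerator, i.e. the marginal p(x_i).
    log_den = np.logaddexp.reduce(log_num, axis=1, keepdims=True)
    return np.exp(log_num - log_den)                              # rows sum to 1

x = np.array([0, 1, 2, 7, 8, 9])
resp = e_step(x, pi=np.array([0.5, 0.5]), lam=np.array([1.0, 8.0]))
print(resp.round(3))
```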

To spell out the function \(\mathcal L(q^t,\theta)\) that we maximize in the M-step (the \(\log q^t(z)\) term of the ELBO is dropped since it does not depend on \(\theta\)):

\[\begin{equation} \begin{split} \hat\theta^t &= \argmax_{\theta} \mathcal L(q^t,\theta) \\\\ &= \argmax_{\theta} \sum_i \mathbb{E}_{z \sim Q^t} \left[\log\left(p(x_i,z;\theta) \right) \right] \\\\ &= \argmax_{\theta} \sum_i \int p(z\vert x_i; \hat\theta^{t-1}) \log\left(p(x_i,z;\theta)\right) dz \\\\ \end{split} \label{eq:M} \end{equation}\]
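
Continuing the hypothetical Poisson-mixture sketch from the E-step above, the maximization in Eq. \(\ref{eq:M}\) has a closed form for that model: weighted maximum likelihood updates of the mixing weights and the Poisson rates under the responsibilities just computed.

```python
import numpy as np

def m_step(x, resp):
    """Maximize sum_i sum_z resp[i, z] * log p(x_i, z; theta) for the Poisson mixture.

    Both updates are closed form:
      pi_k  = average responsibility of component k
      lam_k = responsibility-weighted mean of the data for component k
    """
    nk = resp.sum(axis=0)                         # effective number of points per component
    pi = nk / len(x)
    lam = (resp * x[:, None]).sum(axis=0) / nk
    return pi, lam

# Usage (with x and resp from the E-step sketch above): pi, lam = m_step(x, resp)
```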

With the earlier preparation, we can also easily show the theoretical guarantee of monotonic improvement on the optimization objective \(\ell(\theta)\).

\[\begin{equation}\label{eq:monotone} \ell(\theta^{t-1}) \underset{\text{E-step}}{=} \mathcal L(q^t,\theta^{t-1}) \underset{\text{M-step}}{\le} \mathcal L(q^t,\theta^t) \underset{\text{Jensen}}{\le} \ell(\theta^{t}) \end{equation}\]

Why the “E” in E-step

By the way, the reason it is called the E-step is that in this step we do the calculation necessary to figure out the form of \(\mathcal L(q,\theta)\) as a function of \(\theta\), which we then optimize in the M-step. That form is the expectation of the complete-data log likelihood under the estimated distribution of the latent variable \(z\).

EM as maximization-maximization

The particular choice of \(q^t(z)\) in the E-step makes \(D_{KL}(q(z) \| p_\theta(z\vert x_i))\) vanish, so the E-step can be viewed as maximizing \(\mathcal L(q,\hat\theta^{t-1})\) with respect to \(q\), and the M-step as maximizing it with respect to \(\theta\). So we are doing alternating maximization of the ELBO with respect to \(q\) and \(\theta\).

\[\begin{equation} \begin{split} & \text{E-step:}\hspace{4pt}q^t(z) = \argmax_q \mathcal L(q,\hat\theta^{t-1})\\\\ & \text{M-step:}\hspace{4pt}\hat\theta^t = \argmax_\theta \mathcal L(q^t,\theta) \end{split} \end{equation}\]

This maximization-maximization view offers justification for a partial E-step (when the calculation required in the exact E-step is intractable) and a partial M-step (i.e. finding a \(\theta\) that merely increases the ELBO rather than maximizing it). Under this view, direct maximization of the ELBO as a goal offers a strong connection to variational inference, as discussed briefly below.

Example: Gaussian Mixture

In the context of the Gaussian Mixture Model (GMM), the \(z_i\) associated with \(x_i\) takes a value in \(\{1,2,\dots,n_{g}\}\), where \(n_g\) is the number of Gaussians in the model. Thus \(z_i\) indicates which Gaussian cluster the observed data point \(x_i\) belongs to. The parameter set \(\theta\) includes the parameters of the marginal distribution of \(z\), \(P(Z;\vec \pi)\) with \(\vec \pi = [\pi_1, \pi_2, \dots, \pi_{n_g}]\), \(\sum_{k=1}^{n_g} \pi_k = 1\) and \(\pi_k > 0\), as well as the parameters of the conditional distributions \(P(X \vert Z=k; \mu_k, \sigma_k) = \mathcal N(\mu_k, \sigma_k)\).

For a detailed walk-through, see Andrew Ng’s CS229 lecture notes and video.
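
For completeness, here is a compact sketch of EM for a one-dimensional GMM (my own minimal implementation, not taken from those notes); it also checks the monotonicity guarantee of Eq. \(\ref{eq:monotone}\) by asserting that the log likelihood never decreases across iterations.

```python
import numpy as np
from scipy.stats import norm
from scipy.special import logsumexp

def em_gmm_1d(x, n_g, n_iter=200, seed=0):
    """EM for a 1-D Gaussian mixture with n_g components."""
    rng = np.random.default_rng(seed)
    pi = np.full(n_g, 1.0 / n_g)                   # mixing weights
    mu = rng.choice(x, size=n_g, replace=False)    # initialize means at random data points
    var = np.full(n_g, x.var())                    # component variances

    prev_ll = -np.inf
    for _ in range(n_iter):
        # E-step: responsibilities p(z = k | x_i; theta), Eq. (E).
        log_joint = np.log(pi) + norm.logpdf(x[:, None], mu, np.sqrt(var))
        log_px = logsumexp(log_joint, axis=1)
        resp = np.exp(log_joint - log_px[:, None])

        # M-step: closed-form weighted MLE updates, Eq. (M).
        nk = resp.sum(axis=0)
        pi = nk / len(x)
        mu = (resp * x[:, None]).sum(axis=0) / nk
        var = (resp * (x[:, None] - mu) ** 2).sum(axis=0) / nk

        # Monotonicity check: the incomplete-data log likelihood never decreases.
        ll = log_px.sum()
        assert ll >= prev_ll - 1e-8
        prev_ll = ll
    return pi, mu, var

rng = np.random.default_rng(42)
data = np.concatenate([rng.normal(-3, 1, 300), rng.normal(2, 0.5, 700)])
print(em_gmm_1d(data, n_g=2))
```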

Variants and extensions of EM

GEM and CEM

A popular variant of EM is to merely find, in Eq. \(\ref{eq:M}\), a \(\hat\theta^t\) that increases (rather than maximizes) \(\mathcal L(q^t,\theta)\). It is easy to see that Eq. \(\ref{eq:monotone}\) and the monotonicity guarantee still hold in this situation. This algorithm was proposed in the original EM paper and is called Generalized EM (GEM).

Another variant is the point-estimate version we mentioned earlier, where instead of taking \(q^t(z) = p(z\vert x_i; \hat\theta^{t-1})\) in the E-step, we take \(z\) to be a single value, the most probable one, i.e. \(\hat{z}^t=\argmax_z p(z\vert x_i; \hat\theta^{t-1})\), or equivalently \(q^t(z) = \delta(z-\hat{z}^t)\). In this case, the integral in Eq. \(\ref{eq:M}\) is greatly simplified, but the first equality in Eq. \(\ref{eq:monotone}\) no longer holds and we lose the theoretical guarantee. This algorithm is also called Classification EM (CEM).
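
A sketch of the hard-assignment step, assuming responsibilities computed as in an ordinary E-step: each row is replaced by a one-hot vector at the most probable \(z\), which collapses the integral in Eq. \(\ref{eq:M}\) to a single term per data point.

```python
import numpy as np

def hard_e_step(resp):
    """Classification-EM variant: replace q^t(z) by a point mass at argmax_z p(z | x_i)."""
    z_hat = resp.argmax(axis=1)                  # most probable component per data point
    hard = np.zeros_like(resp)
    hard[np.arange(len(resp)), z_hat] = 1.0      # q^t(z) = delta(z - z_hat)
    return hard

# With equal, fixed component variances, this hard assignment followed by the usual
# mean update is exactly the K-means iteration.
```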

Stochastic EM

As we can see in Eq. \(\ref{eq:M}\), we need to go through all data points in order to update \(\theta\), which could be a long process for large data sets. In much the same spirit as stochastic gradient descent, we could sample subsets of data and run the E- and M-steps on these mini-batches. The same idea can be used for the variational inference mentioned below, on the updates of the global latent variables (such as \(\theta\)).
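
A minimal sketch of this idea, reusing the hypothetical `e_step` and `m_step` functions from the generic skeleton earlier; note that practical stochastic EM implementations usually smooth the sufficient statistics with a decaying step size rather than overwriting the parameters from each mini-batch, as this simplified version does.

```python
import numpy as np

def run_stochastic_em(x, theta0, e_step, m_step, batch_size=256, n_iter=1000, seed=0):
    """Run the E- and M-steps on random mini-batches instead of the full data set."""
    rng = np.random.default_rng(seed)
    theta = theta0
    for _ in range(n_iter):
        batch = x[rng.choice(len(x), size=batch_size, replace=False)]
        q = e_step(batch, theta)      # E-step on the mini-batch only
        theta = m_step(batch, q)      # M-step on the mini-batch only
    return theta
```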

Variational inference

The computation of the optimal \(q(z)\), i.e. \(q(z) = p(z \vert x_i; \hat\theta^{t-1})\) in the E-step, might be intractable. In particular, the integral in the denominator of Eq. \(\ref{eq:E}\) has no closed-form solution for many interesting models. In this case we can take the view of EM as maximization-maximization and try to come up with better and better \(q(z)\) to improve the ELBO. In order to proceed with such variational optimization tasks, we need to specify the functional family \(\mathcal Q\) from which we choose \(q(z)\). Depending on the assumptions, a number of interesting algorithms have been developed; the most popular one is probably the mean-field approximation.

Note that in a typical variational inference framework, the parameter \(\theta\) is treated as a first-class variable that we do inference on (i.e. obtaining \(p(\theta\vert x)\)) rather than something for which we take a single maximum likelihood point estimate, so \(\theta\) becomes part of the latent variables and is absorbed into the notation \(z\). Thus, \(z\) includes global variables such as \(\theta\) and local variables such as the latent labels \(z_i\) associated with each data point \(x_i\).

In the mean-field method, the constraint we put on \(q(z)\) is that it factorizes, i.e. \(q(z) = \prod_k q_k(z_k)\). This says that all latent variables are mutually independent, by assumption. This seemingly simple assumption brings remarkable simplifications to the calculation of the integrals, and especially the expectations of log likelihoods, involved. It leads to a coordinate ascent variational inference (CAVI) algorithm that allows closed-form iterative calculation for certain model families. The coordinate updates on local variables correspond to the E-step in EM, while the updates on global variables correspond to the M-step.
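
As a taste of CAVI, here is a sketch for the toy model used in the Blei et al. review: a mixture of \(K\) unit-variance Gaussians with uniform mixing weights and a \(\mathcal N(0,\sigma_0^2)\) prior on each mean, with a factorized \(q\) consisting of categorical factors over the local labels \(z_i\) and Gaussian factors \(\mathcal N(m_k, s_k^2)\) over the global means. Treat the update formulas as my paraphrase of that review rather than a verified implementation.

```python
import numpy as np
from scipy.special import softmax

def cavi_gmm(x, K, sigma0_sq=10.0, n_iter=100, seed=0):
    """Mean-field CAVI for a unit-variance GMM with a N(0, sigma0_sq) prior on the means."""
    rng = np.random.default_rng(seed)
    m = rng.choice(x, size=K, replace=False)   # variational means of q(mu_k)
    s_sq = np.ones(K)                          # variational variances of q(mu_k)
    for _ in range(n_iter):
        # Local update: q(z_i = k) proportional to exp(E[mu_k] x_i - E[mu_k^2] / 2).
        phi = softmax(x[:, None] * m - 0.5 * (m ** 2 + s_sq), axis=1)
        # Global update: q(mu_k) is Gaussian with the parameters below.
        s_sq = 1.0 / (1.0 / sigma0_sq + phi.sum(axis=0))
        m = s_sq * (phi * x[:, None]).sum(axis=0)
    return phi, m, s_sq
```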

For more about this topic see: D. M. Blei, A. Kucukelbir, and J. D. McAuliffe, “Variational Inference: A Review for Statisticians,” J. Am. Stat. Assoc., vol. 112, no. 518, pp. 859–877, 2017.


References

Todo: add inline citations; for now I have simply dumped some references here.

In no particular order:

  1. A. P. Dempster, N. M. Laird, and D. B. Rubin, “Maximum likelihood from incomplete data via the EM algorithm,” J. R. Stat. Soc. Ser. B Methodol., vol. 39, no. 1, pp. 1–38, 1977.

  2. R. M. Neal and G. E. Hinton, “A View of the EM Algorithm that Justifies Incremental, Sparse, and Other Variants,” in Learning in Graphical Models, pp. 355–368, 1998.

  3. J. A. Bilmes, “A Gentle Tutorial of the EM Algorithm and its Application to Parameter Estimation for Gaussian Mixture and Hidden Markov Models,” ReCALL, vol. 1198, no. 510, p. 126, 1998.

  4. A. Roche, “EM algorithm and variants: an informal tutorial,” pp. 1–17, 2011.

  5. M. R. Gupta, “Theory and Use of the EM Algorithm,” Found. Trends® Signal Process., vol. 4, no. 3, pp. 223–296, 2010.

  6. M. Jordan, Z. Ghahramani, T. S. Jaakkola, and L. K. Saul, “Introduction to variational methods for graphical models,” Mach. Learn., vol. 37, no. 2, pp. 183–233, 1999.

  7. M. J. Wainwright and M. Jordan, “Graphical Models, Exponential Families, and Variational Inference,” Found. Trends® Mach. Learn., vol. 1, no. 1–2, pp. 1–305, 2007.

  8. M. Hoffman, D. M. Blei, C. Wang, and J. Paisley, “Stochastic Variational Inference,” J. Mach. Learn. Res., vol. 14, pp. 1303–1347, 2013.

  9. D. M. Blei, A. Kucukelbir, and J. D. McAuliffe, “Variational Inference: A Review for Statisticians,” J. Am. Stat. Assoc., vol. 112, no. 518, pp. 859–877, 2017.

  10. S. Mohamed, “Variational Inference for Machine Learning,” 2015.

  11. Z. Ghahramani, “Variational Methods: The Expectation Maximization (EM) Algorithm,” 2003.
