Chapter 7 Learning Invertible Generative Models

We saw in Chapter 2 that computing a generative model's posterior is frequently impossible, because computing the normalizer $\hat{p}(\bm{\hat{y}};\bm{\theta})$ requires either an intractable integral or a sum over an exponential number of terms (one for each joint configuration of the variables to be marginalized out, $\bm{\hat{X}}$). Nevertheless, we begin with models for which exact inference is tractable, where the algorithm is particularly elegant. Such models are sometimes described as \emph{invertible} generative models.
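To make the difficulty concrete, the normalizer is just the marginal of the joint distribution over the latent variables: a sum for discrete latents, an integral for continuous ones, neither of which is generally tractable:
\[
\hat{p}(\bm{\hat{y}};\bm{\theta}) \;=\; \sum_{\bm{\hat{x}}} \hat{p}(\bm{\hat{x}},\bm{\hat{y}};\bm{\theta})
\qquad\text{or}\qquad
\hat{p}(\bm{\hat{y}};\bm{\theta}) \;=\; \int \hat{p}(\bm{\hat{x}},\bm{\hat{y}};\bm{\theta})\,\mathrm{d}\bm{\hat{x}}.
\]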

In this classical version of EM, the two optimizations of the joint relative entropy (JRE) are carried out in consecutive steps. At discriminative step $i$, the optimization is trivial:

\begin{equation*}
\begin{split}
\check{p}^{(i+1)}(\bm{\check{x}}|\bm{y})
&= \operatorname*{argmin}_{\check{p}}\left\{
\text{D}_{\text{KL}}\left\{ p(\bm{Y}) \,\middle\|\, \hat{p}(\bm{Y};\bm{\theta}) \right\}
+ \text{D}_{\text{KL}}\left\{ \check{p}(\bm{\check{X}}|\bm{Y}) \,\middle\|\, \hat{p}(\bm{\check{X}}|\bm{Y};\bm{\theta}^{(i)}) \right\}
\right\}\\
&= \hat{p}(\bm{\hat{x}}|\bm{\hat{y}};\bm{\theta}^{(i)}).
\end{split}
\end{equation*}

The conclusion follows because the marginal relative entropy (MRE) doesn't depend on the recognition model ($\check{p}$); because the posterior relative entropy (PRE) is minimal (0) when its arguments are equal; and because we have assumed the generative posterior distribution is computable, and therefore an available choice for the recognition model. Thus the algorithm becomes:

EM Algorithm under exact inference

$\bullet$ E step: $\check{p}^{(i+1)}(\bm{\check{x}}|\bm{y}) \leftarrow \hat{p}(\bm{\hat{x}}|\bm{\hat{y}};\bm{\theta}^{(i)})$
$\bullet$ M step: $\bm{\theta}^{(i+1)} \leftarrow \operatorname*{argmin}_{\bm{\theta}}\left\{\mathcal{L}_{\text{JRE}}(\bm{\theta},\check{p}^{(i+1)})\right\}$.

That is, we will carry out the optimization in Eq. 6.5 with an iterative process, taking expectations at each iteration under the generative posterior from the previous iteration. For example, to update our estimate of the mean of the GMM, we will indeed use Eq. 6.4, except that the posterior under which the expectations are taken will be evaluated at the parameters from the previous iteration. This eliminates the dependence on $\bm{\mu}_k$ from the right-hand side of the equation.
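As a concrete illustration of the boxed algorithm, here is a minimal sketch of EM for a Gaussian mixture in Python. It is not the text's own code: the function names (`log_joint`, `e_step`, `m_step`, `em`) are illustrative, and the per-component variances `sigma2` are assumed spherical and held fixed so the sketch stays short, since the discussion above walks through only the mean update. The E step computes responsibilities at the previous iteration's parameters; the M step then updates the mixture weights and means in closed form.

```python
import numpy as np

def log_joint(Y, mu, weights, sigma2):
    """log[ weights_k * N(y_n; mu_k, sigma2_k * I) ] for every datum n and component k."""
    N, D = Y.shape
    K = mu.shape[0]
    out = np.empty((N, K))
    for k in range(K):
        diff = Y - mu[k]                                   # (N, D)
        out[:, k] = (np.log(weights[k])
                     - 0.5 * np.sum(diff**2, axis=1) / sigma2[k]
                     - 0.5 * D * np.log(2 * np.pi * sigma2[k]))
    return out

def e_step(Y, mu, weights, sigma2):
    """E step: responsibilities, i.e. the exact posterior over components,
    evaluated at the previous iteration's parameters."""
    log_r = log_joint(Y, mu, weights, sigma2)
    log_r -= log_r.max(axis=1, keepdims=True)              # stabilize before exponentiating
    r = np.exp(log_r)
    return r / r.sum(axis=1, keepdims=True)                # (N, K)

def m_step(Y, r):
    """M step: closed-form updates, with expectations under the E-step responsibilities."""
    Nk = r.sum(axis=0)                                     # effective counts per component
    weights = Nk / Y.shape[0]
    mu = (r.T @ Y) / Nk[:, None]                           # responsibility-weighted means
    return mu, weights

def em(Y, mu, weights, sigma2, n_iters=50):
    """Alternate exact E steps and closed-form M steps, tracking log p(Y; theta)."""
    lls = []
    for _ in range(n_iters):
        r = e_step(Y, mu, weights, sigma2)                 # uses theta^(i)
        mu, weights = m_step(Y, r)                         # produces theta^(i+1)
        lj = log_joint(Y, mu, weights, sigma2)
        m = lj.max(axis=1, keepdims=True)                  # log-sum-exp over components
        lls.append(np.sum(m[:, 0] + np.log(np.exp(lj - m).sum(axis=1))))
    return mu, weights, lls
```

Because the E step is exact and the M step is solved in closed form, the log-likelihood trace `lls` should be non-decreasing (up to floating-point error), a handy sanity check for any EM implementation.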

As we saw in the previous chapter, this procedure is guaranteed to decrease (or, more precisely, not to increase) an upper bound on the loss we actually care about, the MRE. In this version of the algorithm, we can say more. At each E step, the PRE in Eq. 6.6 is not merely reduced but eliminated. So at the start of every M step, the bound is tight (Fig. \ref{fig:EMtightBound}). That means that any decrease in JRE at this point entails a decrease in MRE—a better model for the data. Typically, decreases in JRE will also be accompanied by an increase in the PRE. But far from being a bad thing, this corresponds to an even larger decrease in MRE than the decrease in JRE (see again Fig. \ref{fig:EMtightBound}). And at the next E step, the PRE is again eliminated—the bound is again made tight.
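To spell out why a tight bound guarantees improvement, write the three losses in the style of $\mathcal{L}_{\text{JRE}}$ above (the symbols $\mathcal{L}_{\text{MRE}}$ and $\mathcal{L}_{\text{PRE}}$ are introduced here only for this aside), so that $\mathcal{L}_{\text{JRE}} = \mathcal{L}_{\text{MRE}} + \mathcal{L}_{\text{PRE}}$. Then one full EM iteration satisfies
\[
\mathcal{L}_{\text{MRE}}(\bm{\theta}^{(i+1)})
\;\le\;
\mathcal{L}_{\text{JRE}}(\bm{\theta}^{(i+1)},\check{p}^{(i+1)})
\;\le\;
\mathcal{L}_{\text{JRE}}(\bm{\theta}^{(i)},\check{p}^{(i+1)})
\;=\;
\mathcal{L}_{\text{MRE}}(\bm{\theta}^{(i)}),
\]
where the first inequality holds because the PRE is non-negative, the second because the M step does not increase the JRE, and the final equality because the E step set the PRE to zero at $\bm{\theta}^{(i)}$.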

In the examples we consider next, the M step is carried out in closed form, so every parameter update either decreases the MRE or does nothing at all. In contrast, for models in which the M step is carried out by gradient descent, a decrease in JRE need not correspond to a decrease in MRE: As the JRE decreases across the course of the M step, the PRE can open back up—the bound can loosen—and therefore further decreases can correspond to the bound retightening, rather than the MRE decreasing. Indeed, the MRE can increase during this period of bound retightening. But it can never increase above its value at the beginning of the M step, when the bound was tight.

Applying EM.

The EM “algorithm” is in fact a kind of meta-algorithm, or recipe, for estimating densities with latent variables. To derive actual algorithms we need to apply EM to specific graphical models. In the following sections we apply EM to some of the classic latent-variable models.

Our recipe instructs us to minimize the joint relative entropy between $\check{p}(\bm{\check{x}}|\bm{y})\,p(\bm{y})$ and $\hat{p}(\bm{\hat{x}},\bm{\hat{y}};\bm{\theta})$. However, to keep the derivation general, we will avoid specifying $\check{p}(\bm{\check{x}}|\bm{y})$ until as late as possible. Thus, the following derivations apply equally well to fully observed models, in which case

\[
\check{p}(\bm{\check{x}},\bm{y}) \leftarrow p(\bm{x},\bm{y}),
\]

as they do to a single step in EM, in which case

\[
\check{p}(\bm{\check{x}},\bm{y}) \leftarrow \hat{p}(\bm{\hat{x}}|\bm{y};\bm{\theta}^{\text{old}})\,p(\bm{y}).
\]

In either case, the entropy of $\check{p}(\bm{\check{x}},\bm{y})$ is irrelevant to the optimization, since it doesn't depend on the parameters $\bm{\theta}$ optimized in the M step, and the E step is trivial. Therefore we begin all derivations with the joint cross, rather than relative, entropy.
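Explicitly (and only as a restatement of the previous sentence in symbols), the joint relative entropy splits into a cross entropy and an entropy:
\[
\text{D}_{\text{KL}}\left\{ \check{p}(\bm{\check{X}},\bm{Y}) \,\middle\|\, \hat{p}(\bm{\check{X}},\bm{Y};\bm{\theta}) \right\}
=
\underbrace{\mathbb{E}_{\check{p}}\!\left[ -\log \hat{p}(\bm{\check{X}},\bm{Y};\bm{\theta}) \right]}_{\text{joint cross entropy}}
\;-\;
\underbrace{\mathbb{E}_{\check{p}}\!\left[ -\log \check{p}(\bm{\check{X}},\bm{Y}) \right]}_{\text{entropy of } \check{p}},
\]
and only the first term depends on $\bm{\theta}$.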