8.3 Diffusion Models
One desideratum that has emerged from our investigation of directed generative models is for the distribution of latent variables to be essentially structureless. For one thing, this makes it easy to sample the latent variables, and therefore to generate data (since directed models can be sampled with a single “ancestral pass” from root to leaves). For another, it accords with our high-level conceptual goal of explaining observations in terms of a simpler set of independent causes.
Seen from the perspective of the recognition model, this desideratum presents a paradoxical character: it seems that the corresponding goal for the recognition model is to destroy structure. Clearly, it cannot do so in an irreversible way (e.g., by multiplying the observations by zero and adding noise), or there will be no mutual information between observed and latent variables. However, there exists a class of models, known in statistical physics as diffusion processes, which gradually delete structure from distributions but which are nevertheless reversible [46].
A diffusion process can be described by a (very long) Markov chain that (very mildly) corrupts the data at every step. Now, for sufficiently mild data corruption, the process is reversible [46]. Still, we cannot simply apply Bayes’ rule, since it requires the prior distribution over the original, uncorrupted data—i.e., the data distribution, precisely what we want to sample from. On the other hand, it turns out that each reverse-diffusion step must take the same distributional form as the forward-diffusion step [46]. We still don’t know how to convert noisy observations into the parameters of this distribution, but perhaps this mapping can be learned.
In particular, suppose we pair a recognition model describing such a diffusion process with a generative model that is a Markov chain of the same length, and with the same form for its conditional distributions, but pointing in the other direction. Then training the generative model to assign high probability to the observation—or more precisely, lower the joint relative entropy—while making inferences under the recognition model will effectively oblige the generative model to learn to denoise the data at every step. That is, it will become a model of the reverse-diffusion process.
Notice that for this process to be truly reversible, the dimensionality of the data must stay constant across all steps of the Markov chain (including the observations themselves). Also note that, as lately described, the recognition distribution is fixed and requires no learnable parameters. We revisit this idea below.
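To make the first half of this picture concrete, here is a minimal numerical sketch (not from the text; the bimodal “data” distribution, the constant noise rate, and all names are illustrative) of a forward Gaussian diffusion gradually erasing the structure of a distribution:

```python
# A forward Gaussian diffusion gradually destroys the structure of a bimodal
# "data" distribution: after many mild corruption steps, the samples are
# indistinguishable from draws of a standard normal.
import numpy as np

rng = np.random.default_rng(0)
T, beta = 1000, 0.01                       # chain length; (constant) per-step noise rate

# "data": a clearly structured two-component mixture
x = np.concatenate([rng.normal(-3, 0.3, 5000), rng.normal(+3, 0.3, 5000)])

for t in range(T):                         # each step is only mildly corrupting
    x = np.sqrt(1 - beta) * x + np.sqrt(beta) * rng.normal(size=x.shape)

print(f"mean ~ {x.mean():+.2f}, std ~ {x.std():.2f}")   # approx. +0.00 and 1.00
```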
The generative and recognition models.
The model can be expressed most elegantly if we let $\mathbf{x}_0$ denote the observation and $\mathbf{x}_1, \ldots, \mathbf{x}_T$ the latent variables, one for each step of corruption, in which case the generative model factorizes as
$$ p\!\left(\mathbf{x}_0, \ldots, \mathbf{x}_T; \boldsymbol{\theta}\right) = p\!\left(\mathbf{x}_T\right) \prod_{t=1}^{T} p\!\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t; \boldsymbol{\theta}\right), \tag{8.30} $$
i.e., a Markov chain. Notice that the generative prior is not parameterized. This accords with the intuition lately discussed that the model should convert structureless noise into the highly structured distribution of interest. The recognition model likewise simplifies according to the independence statements for a Markov chain:
$$ q\!\left(\mathbf{x}_1, \ldots, \mathbf{x}_T \mid \mathbf{x}_0\right) = \prod_{t=1}^{T} q\!\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right). \tag{8.31} $$
We have omitted the usual parameters from the recognition distributions: as noted above, the recognition model is fixed, with none to learn.
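As a sketch of how the two chains are used (the one-step samplers `corrupt` and `sample_reverse`, and the prior sampler, are hypothetical placeholders for the conditional distributions above):

```python
def sample_recognition(x0, corrupt, T):
    """Run the fixed recognition chain forward from an observation x0,
    returning the latents x_1, ..., x_T (structure is gradually destroyed)."""
    latents, x = [], x0
    for t in range(1, T + 1):
        x = corrupt(x, t)                  # one draw from q(x_t | x_{t-1})
        latents.append(x)
    return latents

def sample_generative(sample_prior, sample_reverse, T):
    """Generate data by ancestral sampling: run the generative chain from the
    structureless prior at t = T down to an observation at t = 0."""
    x = sample_prior()                     # a draw from p(x_T)
    for t in range(T, 0, -1):
        x = sample_reverse(x, t)           # one draw from p(x_{t-1} | x_t)
    return x
```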
The joint relative entropy.
The joint relative entropy, we recall once again, is the difference between the joint cross entropy (between hybrid distribution and generative model) and the entropy of the hybrid distribution,
$$ \mathcal{L}(\boldsymbol{\theta}) = \mathbb{E}\!\left[ -\log p\!\left(\mathbf{x}_{0:T}; \boldsymbol{\theta}\right) \right] - \mathbb{E}\!\left[ -\log\!\left( p_{\text{data}}\!\left(\mathbf{x}_0\right) q\!\left(\mathbf{x}_{1:T} \mid \mathbf{x}_0\right) \right) \right], $$
with both expectations taken under the hybrid distribution $p_{\text{data}}(\mathbf{x}_0)\, q(\mathbf{x}_{1:T} \mid \mathbf{x}_0)$. Given the conditional independencies of a Markov chain, then, the joint relative entropy reduces to a sum of conditional cross entropies (plus constant terms):
$$ \mathcal{L}(\boldsymbol{\theta}) = \sum_{t=1}^{T} \mathbb{E}\!\left[ -\log p\!\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t; \boldsymbol{\theta}\right) \right] + \text{const}. $$
The constant collects the terms that do not depend on the generative parameters, namely the cross entropy of the unparameterized prior, $\mathbb{E}[-\log p(\mathbf{x}_T)]$, and the (negative) entropy of the hybrid distribution, so it can be ignored during training.
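For concreteness, a single-sample Monte Carlo estimate of the parameter-dependent part of this sum might look as follows (a sketch; the callable `log_p_transition` stands in for the generative conditionals):

```python
def chain_negative_log_likelihood(x0, latents, log_p_transition):
    """latents = [x_1, ..., x_T], one draw of the recognition chain run on x0;
    log_p_transition(x_prev, x_next, t) returns log p(x_{t-1} | x_t; theta).
    The sum of negative log transition probabilities estimates the loss
    (up to the constant terms)."""
    xs = [x0] + list(latents)
    return -sum(log_p_transition(xs[t - 1], xs[t], t) for t in range(1, len(xs)))
```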
8.3.1 Gaussian diffusion models
Probably the most intuitive diffusion process is based on Gaussian noise; for example,
$$ q\!\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) = \mathcal{N}\!\left( \sqrt{1-\beta_t}\,\mathbf{x}_{t-1},\; \beta_t \mathbf{I} \right), $$
for some parameters $\beta_1, \ldots, \beta_T \in (0,1)$, each small enough that a single step is only mildly corrupting; the unparameterized generative prior is correspondingly the standard normal, $p(\mathbf{x}_T) = \mathcal{N}(\mathbf{0}, \mathbf{I})$. The generative transitions must take the same (Gaussian) form,
$$ p\!\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t; \boldsymbol{\theta}\right) = \mathcal{N}\!\left( \boldsymbol{\mu}_{\boldsymbol{\theta}}\!\left(\mathbf{x}_t, t\right),\; \sigma_t^2 \mathbf{I} \right). \tag{8.33} $$
The mean and variance of this denoising distribution can depend on the corrupted sample ($\mathbf{x}_t$) and on the step number $t$; here the mean is computed by a neural network with parameters $\boldsymbol{\theta}$, while (for simplicity) the variance is allowed to depend only on the step. With these choices, the decomposed loss becomes
$$ \mathcal{L}(\boldsymbol{\theta}) = \sum_{t=1}^{T} \mathbb{E}_{p_{\text{data}}(\mathbf{x}_0)\, q(\mathbf{x}_{t-1}, \mathbf{x}_t \mid \mathbf{x}_0)}\!\left[ \frac{1}{2\sigma_t^2} \left\lVert \mathbf{x}_{t-1} - \boldsymbol{\mu}_{\boldsymbol{\theta}}\!\left(\mathbf{x}_t, t\right) \right\rVert^2 \right] + \text{const}. \tag{8.34} $$
The interpretation is now clear.
Minimizing the joint relative entropy obliges the generative model to learn how to “undo” one step of corruption with Gaussian noise, for all steps $t = 1, \ldots, T$.
As for the implementation, we see that the integrals in Eq. 8.34 can be approximated with samples: first a draw ($\mathbf{x}_0$) from the data distribution, and then a draw of the corrupted sample ($\mathbf{x}_t$) under the recognition model. Then we will carry out the expectation under the remaining recognition distribution, the reverse transition $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$, in closed form:
$$ \mathcal{L}(\boldsymbol{\theta}) = \sum_{t=1}^{T} \mathbb{E}_{p_{\text{data}}(\mathbf{x}_0)\, q(\mathbf{x}_t \mid \mathbf{x}_0)}\!\Big[ \mathbb{E}_{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}\!\big[ -\log p\!\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t; \boldsymbol{\theta}\right) \big] \Big] + \text{const}. \tag{8.35} $$
Using Eq. 8.35 therefore requires two recognition distributions that we have not yet computed: the marginal $q(\mathbf{x}_t \mid \mathbf{x}_0)$ and the reverse transition $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$. We derive both presently.
The recognition marginals.
A very useful upshot of defining the recognition model to consist only of scaling and the addition of Gaussian noise is that the distribution of any random variable under this model is Gaussian (conditioned, that is, on the original observation $\mathbf{x}_0$). To describe these marginals compactly, define
$$ \bar\alpha_t := \prod_{s=1}^{t} \left(1 - \beta_s\right), \qquad \bar\alpha_0 := 1 . $$
Under this definition, the marginals are
$$ q\!\left(\mathbf{x}_t \mid \mathbf{x}_0\right) = \mathcal{N}\!\left( \sqrt{\bar\alpha_t}\,\mathbf{x}_0,\; \left(1-\bar\alpha_t\right)\mathbf{I} \right). \tag{8.36} $$
The signal is attenuated, and replaced by noise, a little more at every step; for small $\bar\alpha_T$, the final marginal is essentially the structureless standard normal, which is precisely the unparameterized generative prior. Consistency with these marginals requires linear-Gaussian transitions:
$$ q\!\left(\mathbf{x}_t \mid \mathbf{x}_{t-1}\right) = \mathcal{N}\!\left( \sqrt{\tfrac{\bar\alpha_t}{\bar\alpha_{t-1}}}\,\mathbf{x}_{t-1},\; \left(1 - \tfrac{\bar\alpha_t}{\bar\alpha_{t-1}}\right)\mathbf{I} \right) = \mathcal{N}\!\left( \sqrt{1-\beta_t}\,\mathbf{x}_{t-1},\; \beta_t \mathbf{I} \right), \tag{8.37} $$
i.e., exactly the corruptions with which we began. Note that the marginal mean then follows from the law of total expectation,
$$ \mathbb{E}\!\left[ \mathbf{x}_t \mid \mathbf{x}_0 \right] = \sqrt{\tfrac{\bar\alpha_t}{\bar\alpha_{t-1}}}\; \mathbb{E}\!\left[ \mathbf{x}_{t-1} \mid \mathbf{x}_0 \right] = \sqrt{\bar\alpha_t}\,\mathbf{x}_0 .$$
Likewise, by the law of total covariance,
$$ \operatorname{Cov}\!\left[ \mathbf{x}_t \mid \mathbf{x}_0 \right] = \tfrac{\bar\alpha_t}{\bar\alpha_{t-1}}\left(1-\bar\alpha_{t-1}\right)\mathbf{I} + \left(1 - \tfrac{\bar\alpha_t}{\bar\alpha_{t-1}}\right)\mathbf{I} = \left(1-\bar\alpha_t\right)\mathbf{I} .$$
Notice (what our notation implied) that the barred coefficient is a cumulative product of the per-step coefficients: the signal decays, and the noise accumulates, multiplicatively along the chain.
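In code, Eq. 8.36 licenses a “one-shot” corruption routine (a sketch; the linear `beta` schedule and the function names are illustrative, not prescribed by the text):

```python
import numpy as np

# example noise schedule: alpha_bar[t] = prod_{s<=t} (1 - beta_s), with alpha_bar[0] = 1
T = 1000
beta = np.linspace(1e-4, 0.02, T)
alpha_bar = np.concatenate([[1.0], np.cumprod(1.0 - beta)])

def corrupt_to_step(x0, t, rng=np.random.default_rng()):
    """Draw x_t ~ N(sqrt(alpha_bar_t) x0, (1 - alpha_bar_t) I) in a single step (Eq. 8.36)."""
    eps = rng.normal(size=np.shape(x0))
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
```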
The recognition “posterior transitions.”
The other recognition distribution we require in order to use Eq. 8.35 is the “reverse-transition” $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$. Conditioned on $\mathbf{x}_0$, the prior over $\mathbf{x}_{t-1}$ is the Gaussian marginal of Eq. 8.36 (at step $t-1$), and the “likelihood” of $\mathbf{x}_t$ is the linear-Gaussian transition of Eq. 8.37; so the posterior over $\mathbf{x}_{t-1}$ is Gaussian as well.
From Eq. 2.13, the posterior mean is a convex combination of the information from the prior and likelihood:
$$ \tilde{\boldsymbol{\mu}}_t\!\left(\mathbf{x}_t, \mathbf{x}_0\right) = \frac{\beta_t}{1-\bar\alpha_t}\left( \sqrt{\bar\alpha_{t-1}}\,\mathbf{x}_0 \right) + \frac{\left(1-\beta_t\right)\left(1-\bar\alpha_{t-1}\right)}{1-\bar\alpha_t}\left( \frac{\mathbf{x}_t}{\sqrt{1-\beta_t}} \right), $$
the two weights summing to one; and the posterior precision is the sum of the precisions contributed by prior and likelihood.
Assembling the cumulants, we have
$$ q\!\left(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0\right) = \mathcal{N}\!\left( \tilde{\boldsymbol{\mu}}_t\!\left(\mathbf{x}_t, \mathbf{x}_0\right),\; \tilde\sigma_t^2 \mathbf{I} \right), \qquad \tilde\sigma_t^2 = \frac{\left(1-\bar\alpha_{t-1}\right)\beta_t}{1-\bar\alpha_t} . \tag{8.38} $$
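A small helper assembling these cumulants (a sketch, reusing the illustrative `alpha_bar` schedule from the previous snippet):

```python
import numpy as np

def posterior_transition(x0, xt, t, alpha_bar):
    """Mean and (isotropic) variance of q(x_{t-1} | x_t, x_0), per Eq. 8.38."""
    a_bar_t, a_bar_prev = alpha_bar[t], alpha_bar[t - 1]
    beta_t = 1.0 - a_bar_t / a_bar_prev
    mean = (np.sqrt(a_bar_prev) * beta_t / (1.0 - a_bar_t)) * x0 \
         + (np.sqrt(1.0 - beta_t) * (1.0 - a_bar_prev) / (1.0 - a_bar_t)) * xt
    var = (1.0 - a_bar_prev) * beta_t / (1.0 - a_bar_t)
    return mean, var
```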
The reverse-transition cross entropies, revisited.
We noted above that one way to evaluate the joint relative entropy for the diffusion model is to form Monte Carlo estimates of each of the summands in Eq. 8.34.
Naïvely, we could evaluate each summand with samples from three random variables ($\mathbf{X}_0$, $\mathbf{X}_{t-1}$, and $\mathbf{X}_t$). But since both the reverse transition (Eq. 8.38) and the generative transition (Eq. 8.33) are Gaussian, the innermost expectation in Eq. 8.35, the reverse-transition cross entropy, can instead be carried out in closed form:
$$
\begin{aligned}
\mathcal{L}(\boldsymbol{\theta})
&= \sum_{t=1}^{T} \mathbb{E}_{p_{\text{data}}(\mathbf{x}_0)\, q(\mathbf{x}_t \mid \mathbf{x}_0)}\!\Big[ \mathbb{E}_{q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)}\!\Big[ \tfrac{1}{2\sigma_t^2} \big\lVert \mathbf{x}_{t-1} - \boldsymbol{\mu}_{\boldsymbol{\theta}}\!\left(\mathbf{x}_t, t\right) \big\rVert^2 \Big] \Big] + \text{const} \\
&= \sum_{t=1}^{T} \mathbb{E}_{p_{\text{data}}(\mathbf{x}_0)\, q(\mathbf{x}_t \mid \mathbf{x}_0)}\!\Big[ \tfrac{1}{2\sigma_t^2} \big\lVert \tilde{\boldsymbol{\mu}}_t\!\left(\mathbf{x}_t, \mathbf{x}_0\right) - \boldsymbol{\mu}_{\boldsymbol{\theta}}\!\left(\mathbf{x}_t, t\right) \big\rVert^2 \Big] + \text{const}.
\end{aligned} \tag{8.39}
$$
In moving to the second line, the expectation of the quadratic form was taken under $q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)$, and the resulting trace term, which is proportional to the posterior variance $\tilde\sigma_t^2$ and independent of the parameters, was absorbed into the constant.
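The closed-form step just performed is simply the cross entropy between two isotropic Gaussians, which can be coded directly (a sketch; names are illustrative):

```python
import numpy as np

def gaussian_cross_entropy(mu_q, var_q, mu_p, var_p):
    """E_q[-log p] for q = N(mu_q, var_q I) and p = N(mu_p, var_p I):
    the squared distance between the means plus a trace term, scaled by var_p,
    plus the log-normalizer of p. No samples of x_{t-1} are needed."""
    k = np.size(mu_q)                                   # dimensionality
    quad = np.sum((mu_q - mu_p) ** 2) + k * var_q       # E_q ||x - mu_p||^2
    return 0.5 * (k * np.log(2.0 * np.pi * var_p) + quad / var_p)
```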
How are we to interpret Eq. 8.39?
It would be convenient if this could also be expressed as the mean squared error between an uncorrupted sample and a predictor—call it $\hat{\mathbf{x}}_0(\mathbf{x}_t, t; \boldsymbol{\theta})$—of that sample from its corrupted descendant. This can be arranged by parameterizing the generative mean in the same form as the posterior mean of Eq. 8.38, but with the (unavailable) uncorrupted sample replaced by the prediction:
$$ \boldsymbol{\mu}_{\boldsymbol{\theta}}\!\left(\mathbf{x}_t, t\right) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,\hat{\mathbf{x}}_0\!\left(\mathbf{x}_t, t; \boldsymbol{\theta}\right) + \frac{\sqrt{1-\beta_t}\left(1-\bar\alpha_{t-1}\right)}{1-\bar\alpha_t}\,\mathbf{x}_t . $$
Note that this reparameterization loses no generality. In terms of this predictor, the joint relative entropy then becomes
$$ \mathcal{L}(\boldsymbol{\theta}) = \sum_{t=1}^{T} \mathbb{E}_{p_{\text{data}}(\mathbf{x}_0)\, q(\mathbf{x}_t \mid \mathbf{x}_0)}\!\left[ \frac{\bar\alpha_{t-1}\,\beta_t^2}{2\sigma_t^2\left(1-\bar\alpha_t\right)^2} \left\lVert \mathbf{x}_0 - \hat{\mathbf{x}}_0\!\left(\mathbf{x}_t, t; \boldsymbol{\theta}\right) \right\rVert^2 \right] + \text{const}. \tag{8.40} $$
So as in Eq. 8.34, minimizing the loss amounts to optimizing a denoising function.
But in this case, it is the completely uncorrupted data samples, $\mathbf{x}_0$, that must be recovered, and from corrupted samples that may lie arbitrarily far along the chain, rather than merely one step further on.
Up till now we have refrained from specifying the generative variances $\sigma_t^2$. Since they are not learned, we are free to choose them for convenience; one natural choice makes the per-step weights in Eq. 8.40 constant,
$$ \sigma_t^2 = \frac{\bar\alpha_{t-1}\,\beta_t^2}{\left(1-\bar\alpha_t\right)^2}, $$
in which case Eq. 8.40 simplifies to the even more elegant
$$ \mathcal{L}(\boldsymbol{\theta}) = \sum_{t=1}^{T} \mathbb{E}_{p_{\text{data}}(\mathbf{x}_0)\, q(\mathbf{x}_t \mid \mathbf{x}_0)}\!\left[ \tfrac{1}{2} \left\lVert \mathbf{x}_0 - \hat{\mathbf{x}}_0\!\left(\mathbf{x}_t, t; \boldsymbol{\theta}\right) \right\rVert^2 \right] + \text{const}. \tag{8.41} $$
In words, each summand in Eq. 8.41 computes the mean squared error between the uncorrupted data and the model’s best guess at them, given only a corrupted version from step $t$ of the chain.
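In code, one Monte Carlo summand of this simplified objective might look as follows (a PyTorch sketch; the network `xhat0_net`, which maps a batch of corrupted samples and their step indices to predictions of the clean samples, is an assumption, not something given by the text):

```python
import torch

def x0_prediction_loss(xhat0_net, x0, t, alpha_bar):
    """0.5 * ||x0 - xhat0(x_t, t)||^2, with x_t drawn from the marginal of Eq. 8.36.
    x0: batch of data; t: batch of step indices; alpha_bar: tensor of length T + 1."""
    a_bar = alpha_bar[t].view(-1, *([1] * (x0.dim() - 1)))   # broadcast over the batch
    eps = torch.randn_like(x0)
    xt = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * eps   # one-shot corruption
    xhat0 = xhat0_net(xt, t)
    return 0.5 * ((x0 - xhat0) ** 2).flatten(1).sum(dim=1).mean()
```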
Implementation
The continuous-time limit.
Eq. 8.36 tells us that the data can be corrupted to an arbitrary position in the Markov chain with a single computation. Consequently, it is not actually necessary to run the chain sequentially from 1 to $T$ in order to estimate the loss: at each iteration of training, we can instead evaluate a single, randomly selected summand of Eq. 8.41,
$$ \tfrac{1}{2} \left\lVert \mathbf{x}_0 - \hat{\mathbf{x}}_0\!\left( \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon},\; t;\; \boldsymbol{\theta} \right) \right\rVert^2, $$
with $\mathbf{x}_0$ a draw from the data, $\boldsymbol{\epsilon}$ a draw of standard normal noise, and $t$ drawn uniformly from $\{1, \ldots, T\}$ (the factor of $T$ that converts this average over steps back into the sum of Eq. 8.41 can be absorbed into the learning rate).
Notice that nothing in this procedure requires the chain to be short. On the contrary, for very large $T$ we may as well pass to the continuous-time limit, treating the step as a real-valued variable and the schedule $\bar\alpha_t$ as a smooth, decreasing function of it.
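Assembled into a training loop (a PyTorch sketch; the prediction network, the data loader, and the schedule tensor are assumed inputs, and all names are illustrative):

```python
import torch

def train(xhat0_net, data_loader, alpha_bar, n_epochs=10, lr=1e-4):
    """alpha_bar: tensor of length T + 1 with alpha_bar[0] = 1."""
    T = alpha_bar.numel() - 1
    opt = torch.optim.Adam(xhat0_net.parameters(), lr=lr)
    for epoch in range(n_epochs):
        for x0 in data_loader:                               # batches of uncorrupted data
            t = torch.randint(1, T + 1, (x0.shape[0],))      # one random step per example
            a_bar = alpha_bar[t].view(-1, *([1] * (x0.dim() - 1)))
            eps = torch.randn_like(x0)
            xt = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * eps  # corrupt in one shot
            loss = 0.5 * ((x0 - xhat0_net(xt, t)) ** 2).flatten(1).sum(dim=1).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
```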
A connection to denoising score matching
There is in fact another illuminating reparameterization, this time in terms of the negative energy gradient††margin: negative energy gradient , that is, the gradient of the log-probability,
$$ \nabla_{\mathbf{x}} \log q(\mathbf{x}) $$
(with the “energy” of a distribution understood to be its negative log-probability, up to an additive constant). We will also refer to this quantity, more briefly, as the force.
Consider a random variable $\mathbf{X}_t$ produced by scaling a sample from the data distribution and adding Gaussian noise, as under the recognition marginal, Eq. 8.36. Its density,
$$ q(\mathbf{x}_t) = \int p_{\text{data}}\!\left(\mathbf{x}_0\right) q\!\left(\mathbf{x}_t \mid \mathbf{x}_0\right) \mathrm{d}\mathbf{x}_0 , $$
can be thought of as a kernel-density estimate of the distribution of interest, with the Gaussian corruption kernel supplying the smoothing. Differentiating its logarithm, and exchanging the order of differentiation and integration, we find that
$$ \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t) = \mathbb{E}_{q(\mathbf{x}_0 \mid \mathbf{x}_t)}\!\left[ \nabla_{\mathbf{x}_t} \log q\!\left(\mathbf{x}_t \mid \mathbf{x}_0\right) \right]. \tag{8.45} $$
This equation says that the force of the marginal, smoothed distribution is the posterior expectation of the force of the corruption kernel.
Putting this together with Eq. 8.45, we see that for additive Gaussian noise,
$$ \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t) = \mathbb{E}_{q(\mathbf{x}_0 \mid \mathbf{x}_t)}\!\left[ -\frac{\mathbf{x}_t - \sqrt{\bar\alpha_t}\,\mathbf{x}_0}{1-\bar\alpha_t} \right] = \frac{\sqrt{\bar\alpha_t}\,\mathbb{E}\!\left[\mathbf{x}_0 \mid \mathbf{x}_t\right] - \mathbf{x}_t}{1-\bar\alpha_t} . $$
This is sometimes called Tweedie’s formula††margin: Tweedie’s formula . This looks helpful—if we had in hand the posterior mean!
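The formula is easy to check numerically. In the one-dimensional sketch below (illustrative only; the mixture “data” distribution and all constants are arbitrary), the posterior mean and the right-hand side of Tweedie’s formula are computed independently, the latter from a finite-difference estimate of the force of the corrupted marginal, and the two agree:

```python
import numpy as np

a_bar = 0.6                                     # signal coefficient at some step t
grid = np.linspace(-10, 10, 4001)               # integration grid over x0
p_data = 0.5 * np.exp(-0.5 * (grid + 2) ** 2) + 0.5 * np.exp(-0.5 * (grid - 2) ** 2)
p_data /= np.trapz(p_data, grid)                # a bimodal "data" distribution

def kernel(xt):                                 # q(x_t | x_0), evaluated on the grid
    return np.exp(-0.5 * (xt - np.sqrt(a_bar) * grid) ** 2 / (1 - a_bar)) \
           / np.sqrt(2 * np.pi * (1 - a_bar))

def marginal(xt):                               # q(x_t) = int p_data(x0) q(x_t | x0) dx0
    return np.trapz(p_data * kernel(xt), grid)

def posterior_mean(xt):                         # E[x_0 | x_t], by quadrature
    w = p_data * kernel(xt)
    return np.trapz(grid * w, grid) / np.trapz(w, grid)

xt, h = 1.3, 1e-4
force = (np.log(marginal(xt + h)) - np.log(marginal(xt - h))) / (2 * h)
tweedie = (xt + (1 - a_bar) * force) / np.sqrt(a_bar)
print(posterior_mean(xt), tweedie)              # the two estimates agree
```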
Now it is a fact that, of all estimators for $\mathbf{x}_0$ from the corrupted sample $\mathbf{x}_t$, the one with the smallest mean squared error is the posterior mean, $\mathbb{E}[\mathbf{x}_0 \mid \mathbf{x}_t]$. This suggests a procedure for estimating the force of a distribution known only through samples: corrupt the samples with (scaled) additive Gaussian noise, fit a regressor from corrupted to clean samples by minimizing squared error, and then convert the resulting approximation to the posterior mean into a force estimate via Tweedie’s formula.
Now we examine the diffusion model in light of this procedure.
Fitting the generative model to the data has turned out to be equivalent to minimizing the mean squared error between the uncorrupted samples, $\mathbf{x}_0$, and predictions, $\hat{\mathbf{x}}_0(\mathbf{x}_t, t; \boldsymbol{\theta})$, made from their corrupted descendants (Eq. 8.41). At the optimum, then, the predictor approximates the posterior mean; and therefore, by Tweedie’s formula, it also furnishes an estimate of the force of every smoothed marginal along the chain:
$$ \nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t) \approx \frac{\sqrt{\bar\alpha_t}\,\hat{\mathbf{x}}_0\!\left(\mathbf{x}_t, t; \boldsymbol{\theta}\right) - \mathbf{x}_t}{1-\bar\alpha_t} . \tag{8.47} $$
The force estimator, in other words, comes for free, albeit indirectly, by way of the posterior-mean estimator.
Alternatively, the force can be fit directly, rather than indirectly via the posterior mean.
That is easily done here as well, simply by rearranging Eq. 8.47 to reparameterize the predictor in terms of a force estimator,
$$ \hat{\mathbf{x}}_0\!\left(\mathbf{x}_t, t; \boldsymbol{\theta}\right) = \frac{\mathbf{x}_t + \left(1-\bar\alpha_t\right) \hat{\mathbf{f}}\!\left(\mathbf{x}_t, t; \boldsymbol{\theta}\right)}{\sqrt{\bar\alpha_t}} , $$
for some arbitrary function (neural network) $\hat{\mathbf{f}}(\mathbf{x}_t, t; \boldsymbol{\theta})$. Substituting into Eq. 8.41, the loss becomes
$$ \mathcal{L}(\boldsymbol{\theta}) = \sum_{t=1}^{T} \mathbb{E}_{p_{\text{data}}(\mathbf{x}_0)\, q(\mathbf{x}_t \mid \mathbf{x}_0)}\!\left[ \frac{\left(1-\bar\alpha_t\right)^2}{2\,\bar\alpha_t} \left\lVert \hat{\mathbf{f}}\!\left(\mathbf{x}_t, t; \boldsymbol{\theta}\right) - \nabla_{\mathbf{x}_t} \log q\!\left(\mathbf{x}_t \mid \mathbf{x}_0\right) \right\rVert^2 \right] + \text{const}. \tag{8.48} $$
The last step follows from the reparameterization $\mathbf{x}_t = \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon}$, under which $\mathbf{x}_0 - \hat{\mathbf{x}}_0 = -\frac{1-\bar\alpha_t}{\sqrt{\bar\alpha_t}}\big( \hat{\mathbf{f}} + \boldsymbol{\epsilon}/\sqrt{1-\bar\alpha_t} \big)$ and $\nabla_{\mathbf{x}_t} \log q(\mathbf{x}_t \mid \mathbf{x}_0) = -\boldsymbol{\epsilon}/\sqrt{1-\bar\alpha_t}$.
Intuitively, the force estimator must learn, at every noise level, to point from a corrupted sample back toward the data that could plausibly have produced it; that is, in the direction of increasing probability under the corresponding smoothed marginal.
In either case, each summand in Eq. 8.48 corresponds to an objective for denoising score matching; or, put the other way around, fitting a Gaussian (reverse-)diffusion model (Eqs. 8.30, 8.31, 8.37, and 8.33) amounts to running denoising score matching for many different kernel widths. And indeed, such a learning procedure has been proposed and justified independently under this description [48].
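A sketch of the direct parameterization, in the same style as the earlier snippets (the force network `force_net` is hypothetical; the regression target is the force of the Gaussian corruption kernel):

```python
import torch

def denoising_score_matching_loss(force_net, x0, t, alpha_bar):
    """Weighted squared error between the estimated force and the force of the
    corruption kernel, grad_x log q(x_t | x_0) = -eps / sqrt(1 - alpha_bar_t)."""
    a_bar = alpha_bar[t].view(-1, *([1] * (x0.dim() - 1)))
    eps = torch.randn_like(x0)
    xt = torch.sqrt(a_bar) * x0 + torch.sqrt(1.0 - a_bar) * eps
    target = -eps / torch.sqrt(1.0 - a_bar)
    weight = (1.0 - a_bar) ** 2 / (2.0 * a_bar)           # per-step weight of Eq. 8.48
    return (weight * (force_net(xt, t) - target) ** 2).flatten(1).sum(dim=1).mean()
```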