8.3 Diffusion Models
One desideratum that has emerged from our investigation of directed generative models is for the distribution of latent variables to be essentially structureless. For one thing, this makes it easy to sample the latent variables, and therefore to generate data (since directed models can be sampled with a single “ancestral pass” from root to leaves). For another, it accords with our high-level conceptual goal of explaining observations in terms of a simpler set of independent causes.
Seen from the perspective of the recognition model, this desideratum presents a paradoxical character: it seems that the corresponding goal for the recognition model is to destroy structure. Clearly, it cannot do so in an irreversible way (e.g., by multiplying the observations by zero and adding noise), or there will be no mutual information between observed and latent variables. However, there exists a class of models, known in statistical physics as diffusion processes, which gradually delete structure from distributions but which are nevertheless reversible [46].
A diffusion process can be described by a (very long) Markov chain that (very mildly) corrupts the data at every step. Now, for sufficiently mild data corruption, the process is reversible [46]. Still, we cannot simply apply Bayes’ rule, since it requires the prior distribution over the original, uncorrupted data—i.e., the data distribution, precisely what we want to sample from. On the other hand, it turns out that each reverse-diffusion step must take the same distributional form as the forward-diffusion step [46]. We still don’t know how to convert noisy observations into the parameters of this distribution, but perhaps this mapping can be learned.
In particular, suppose we pair a recognition model describing such a diffusion process with a generative model that is a Markov chain of the same length, and with the same form for its conditional distributions, but pointing in the other direction. Then training the generative model to assign high probability to the observation—or, more precisely, to lower the joint relative entropy—while making inferences under the recognition model will effectively oblige the generative model to learn to denoise the data at every step. That is, it will become a model of the reverse-diffusion process.
Notice that for this process to be truly reversible, the dimensionality of the data must stay constant across all steps of the Markov chain (including the observations themselves). Also note that, as lately described, the recognition distribution is fixed and requires no learnable parameters. We revisit this idea below.
The generative and recognition models.
The model can be expressed most elegantly if we use $X_0$ for the observed data, and likewise for their counterparts in the generative model, $\hat{X}_0$—so in this section we do. Then the diffusion generative model can be written simply as

$$\hat{p}(x_0, \ldots, x_T; \theta) = \hat{p}(x_T) \prod_{t=1}^{T} \hat{p}(x_{t-1} \mid x_t; \theta), \qquad (8.30)$$

i.e., a Markov chain. Notice that the generative prior $\hat{p}(x_T)$ is not parameterized. This accords with the intuition lately discussed that the model should convert structureless noise into the highly structured distribution of interest. The recognition model likewise simplifies according to the independence statements for a Markov chain:

$$q(x_1, \ldots, x_T \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1}). \qquad (8.31)$$

We have omitted the usual recognition parameters $\phi$ from this model because by assumption it has no learnable parameters.
The joint relative entropy.
The joint relative entropy, we recall once again, is the difference between the joint cross entropy (between the hybrid distribution and the generative model) and the entropy of the hybrid distribution. However, in the diffusion model, the recognition model is parameterless, so the entire entropy term is a constant as far as our optimization goes. That is, our optimization is essentially “all M step” (improving the generative model), and therefore can be carried out on the joint cross entropy. This would be a mistake if the recognition model were (as usual) to be interpreted as merely an approximation to the posterior under the generative model. In the diffusion model, in contrast, the recognition model is interpreted as ground truth, for which the generative model provides the approximation. The overall loss can be reduced by improving the generative fit either to the data or to the recognition model, but in this case both are desirable per se.
Given the conditional independencies of a Markov chain, then, the joint relative entropy reduces to a sum of conditional cross entropies (plus constant terms):

$$\mathcal{L}(\theta) = \sum_{t=1}^{T} \mathbb{E}_{X_{t-1}, X_t}\!\big[{-\log \hat{p}(X_{t-1} \mid X_t; \theta)}\big] + c, \qquad (8.32)$$

with expectations taken under the hybrid distribution (the data distribution times the recognition model). The $c$ denotes different constants on different lines. Now we need to choose generative and recognition models that make the integrals (or sums) in these cross entropies tractable.
8.3.1 Gaussian diffusion models
Probably the most intuitive diffusion process is based on Gaussian noise; for example,

$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\big(\lambda_t x_{t-1},\ \eta_t^2 I\big),$$

for some parameters $\lambda_t$ and $\eta_t^2$. Essentially, the process scales (down) the data and adds isotropic noise. However, we defer specifying these parameters for the moment, and turn directly to the generative model. Suffice it to say, if $\lambda_t$ is sufficiently close to 1 and $\eta_t^2$ is sufficiently small, then the generative transitions are likewise Gaussian (see the discussion above). Furthermore, after many diffusion steps, the distribution of the state will be Gaussian and isotropic. With the appropriate selection of the recognition parameters, we can force this distribution to have zero mean and unit variance. Therefore we define the generative model to be

$$\hat{p}(x_T) = \mathcal{N}(0,\ I), \qquad \hat{p}(x_{t-1} \mid x_t; \theta) = \mathcal{N}\!\big(\hat{\mu}_\theta(x_t, t),\ \hat{\Sigma}_\theta(x_t, t)\big). \qquad (8.33)$$
The mean and variance of this denoising distribution can depend on the corrupted sample ($x_t$) in a complicated way, so in general we can let $\hat{\mu}_\theta$ and $\hat{\Sigma}_\theta$ be neural networks. Nevertheless, for simplicity in the derivation, and because learning variances is significantly more challenging than learning means, let us further replace the variance function with a set of fixed (i.e., not learned), data-independent parameters, $\hat{\Sigma}_\theta(x_t, t) = \hat{\sigma}_t^2 I$. (To save space, we also write the mean function as $\hat{\mu}_t(x_t)$, but it certainly does depend on the data and the generative-model parameters.) Then the joint relative entropy (Eq. 8.32) can be expressed as

$$\mathcal{L}(\theta) = \sum_{t=1}^{T} \mathbb{E}_{X_{t-1}, X_t}\!\left[\frac{1}{2\hat{\sigma}_t^2}\big\| X_{t-1} - \hat{\mu}_t(X_t) \big\|^2\right] + c. \qquad (8.34)$$
The interpretation is now clear. Minimizing the joint relative entropy obliges the generative model to learn how to “undo” one step of corruption with Gaussian noise, for all steps $t$. At the model’s disposal is the arbitrarily powerful function (neural network) $\hat{\mu}_t$. Since the amount of corruption can vary with $t$, optimizing Eq. 8.34 obliges $\hat{\mu}_t$ to be able to remove noise of possibly different sizes.
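Concretely, Eq. 8.34 can be estimated by running the forward chain once and summing the per-step denoising errors. Here is a minimal NumPy sketch of that naive estimator; the schedule values (`lam`, `eta`) and the stand-in denoiser `mu_hat` are assumptions for illustration only, where in practice `mu_hat` would be a trained network:

```python
import numpy as np

rng = np.random.default_rng(0)

T = 1000                      # number of diffusion steps
lam = np.sqrt(1.0 - 1e-3)     # per-step scale lambda_t, close to 1 (assumed constant)
eta = np.sqrt(1e-3)           # per-step noise std eta_t, small (assumed constant)

def mu_hat(x_t, t):
    """Stand-in for the denoising network mu_hat_t(x_t); an untrained guess."""
    return lam * x_t

def naive_loss(x0, sigma_hat2=1e-3):
    """One-sample Monte Carlo estimate of Eq. 8.34: run the forward chain
    once from x0, accumulating the per-step squared denoising errors."""
    loss, x_prev = 0.0, x0
    for t in range(1, T + 1):
        x_t = lam * x_prev + eta * rng.standard_normal(x0.shape)  # corrupt
        loss += np.sum((x_prev - mu_hat(x_t, t)) ** 2) / (2 * sigma_hat2)
        x_prev = x_t
    return loss

print(naive_loss(rng.standard_normal(2)))   # a fake two-dimensional "datum"
```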
As for the implementation, we see that the integrals in Eq. 8.34 can be approximated with samples: first a draw ($x_0$) from the data distribution, followed by draws ($x_1, \ldots, x_T$) down the length of the recognition model. Notice, however, that under this scheme, each summand would be estimated with samples from three random variables, $(X_0, X_{t-1}, X_t)$. We can reduce this by one, and thereby reduce the variance of our Monte Carlo estimator, by exploiting some properties of Gaussian noise. In particular, we will reverse the order of expansion in applying the chain rule of probability to the recognition model:

$$q(x_{t-1}, x_t \mid x_0) = q(x_t \mid x_0)\, q(x_{t-1} \mid x_t, x_0). \qquad (8.35)$$
Then we will carry out the expectation under $q(x_{t-1} \mid x_t, x_0)$ in closed form. In preparation, we now turn to $q(x_t \mid x_0)$ and $q(x_{t-1} \mid x_t, x_0)$.
The recognition marginals.
A very useful upshot of defining the recognition model to consist only of scaling and the addition of Gaussian noise is that the distribution of any random variable under this model is Gaussian (conditioned, that is, on $x_0$). That includes $q(x_t \mid x_0)$, which we might call the “recognition marginal.” For reasons that will soon become clear, we make these marginals the starting point of our definition of the recognition model [26], and then work backwards to the transition probabilities:

$$q(x_t \mid x_0) = \mathcal{N}\!\big(\alpha_t x_0,\ \sigma_t^2 I\big), \qquad \mathrm{SNR}(t) := \frac{\alpha_t^2}{\sigma_t^2}. \qquad (8.36)$$

Under this definition, $\mathrm{SNR}(t)$ is a kind of signal-to-noise ratio. We require it to decrease monotonically with $t$.
Consistency with these marginals requires linear-Gaussian transitions:

$$q(x_t \mid x_s) = \mathcal{N}\!\big(\alpha_{t|s}\, x_s,\ \sigma_{t|s}^2 I\big).$$

Note that $s$ and $t$ need not even be consecutive steps in the Markov chain, although we require $s < t$. Furthermore, by the law of total expectation,

$$\alpha_t x_0 = \mathbb{E}[X_t \mid x_0] = \mathbb{E}\big[\mathbb{E}[X_t \mid X_s] \,\big|\, x_0\big] = \alpha_{t|s}\, \alpha_s\, x_0 \quad\implies\quad \alpha_{t|s} = \frac{\alpha_t}{\alpha_s}.$$

Likewise, by the law of total covariance,

$$\sigma_t^2 I = \sigma_{t|s}^2 I + \alpha_{t|s}^2\, \sigma_s^2 I \quad\implies\quad \sigma_{t|s}^2 = \sigma_t^2 - \frac{\alpha_t^2}{\alpha_s^2}\, \sigma_s^2.$$

Notice (what our notation implied) that $\alpha_{t|s}$ and $\sigma_{t|s}^2$ are necessarily scalars. In fine, the conditional recognition probabilities are given by

$$q(x_t \mid x_s) = \mathcal{N}\!\left(\frac{\alpha_t}{\alpha_s}\, x_s,\ \left(\sigma_t^2 - \frac{\alpha_t^2}{\alpha_s^2}\, \sigma_s^2\right) I\right). \qquad (8.37)$$
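The recognition marginals and transitions are simple enough to implement directly. The sketch below (which later sketches in this section reuse) adopts an assumed schedule, with the log-SNR falling linearly and the process variance-preserving ($\alpha_t^2 + \sigma_t^2 = 1$); the text leaves both choices open. It samples from Eq. 8.36, computes the transition cumulants of Eq. 8.37, and checks that chaining $0 \to s \to t$ reproduces the marginal cumulants at $t$:

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
steps = np.arange(T + 1)

# Assumed schedule: log-SNR falls linearly from +10 to -10; variance-preserving.
log_snr = 10.0 - 20.0 * steps / T
alpha = np.sqrt(1.0 / (1.0 + np.exp(-log_snr)))   # alpha_t
sigma = np.sqrt(1.0 - alpha ** 2)                 # sigma_t

def marginal_sample(x0, t):
    """Draw x_t ~ q(x_t | x_0) = N(alpha_t x_0, sigma_t^2 I)   (Eq. 8.36)."""
    return alpha[t] * x0 + sigma[t] * rng.standard_normal(x0.shape)

def transition_params(s, t):
    """alpha_{t|s} and sigma_{t|s}^2 of q(x_t | x_s), for s < t (Eq. 8.37)."""
    a_ts = alpha[t] / alpha[s]
    var_ts = sigma[t] ** 2 - a_ts ** 2 * sigma[s] ** 2
    return a_ts, var_ts

# Sanity checks: chaining 0 -> s -> t reproduces the marginal cumulants at t,
# and a monotonically decreasing SNR keeps the transition variance positive.
s, t = 300, 700
a_ts, var_ts = transition_params(s, t)
assert np.isclose(a_ts * alpha[s], alpha[t])
assert np.isclose(var_ts + a_ts ** 2 * sigma[s] ** 2, sigma[t] ** 2)
assert var_ts > 0
```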
The recognition “posterior transitions.”
The other recognition distribution we require in order to use Eq. 8.35 is the “reverse transition” $q(x_s \mid x_t, x_0)$. Here we again solve the more generic case, in which the state at step $s$ precedes the state at step $t$, but not necessarily directly. This distribution is again normal (all recognition distributions are), although this time the calculation of the cumulants is slightly more complicated, since it requires Bayes’ rule. Here the “prior” is $q(x_s \mid x_0)$ (given by the definition of the recognition marginals, Eq. 8.36); and the “likelihood” (or emission) is $q(x_t \mid x_s)$ (the conditional recognition probabilities, given by Eq. 8.37). We have worked out the general case of Bayes’ rule for jointly Gaussian random variables in Section 2.1.2. From Eq. 2.14, the posterior precision is the sum of the (unnormalized) prior and likelihood precisions (in the space of $x_s$):

$$\frac{1}{\sigma_{s|t}^2} = \frac{1}{\sigma_s^2} + \frac{\alpha_{t|s}^2}{\sigma_{t|s}^2} \quad\implies\quad \sigma_{s|t}^2 = \frac{\sigma_{t|s}^2\, \sigma_s^2}{\sigma_t^2}.$$

From Eq. 2.13, the posterior mean is a convex combination of the information from the prior and likelihood:

$$\mu_{s|t}(x_t, x_0) = \sigma_{s|t}^2 \left(\frac{\alpha_{t|s}}{\sigma_{t|s}^2}\, x_t + \frac{\alpha_s}{\sigma_s^2}\, x_0\right) = \frac{\alpha_{t|s}\, \sigma_s^2}{\sigma_t^2}\, x_t + \frac{\alpha_s\, \sigma_{t|s}^2}{\sigma_t^2}\, x_0.$$

Assembling the cumulants, we have

$$q(x_s \mid x_t, x_0) = \mathcal{N}\!\left(\frac{\alpha_{t|s}\, \sigma_s^2}{\sigma_t^2}\, x_t + \frac{\alpha_s\, \sigma_{t|s}^2}{\sigma_t^2}\, x_0,\ \ \frac{\sigma_{t|s}^2\, \sigma_s^2}{\sigma_t^2}\, I\right). \qquad (8.38)$$
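Continuing the sketch above (reusing `alpha`, `sigma`, and `transition_params`), the posterior cumulants of Eq. 8.38 follow in a few lines; this is the distribution whose mean the generative model will learn to match, and whose variance we will later adopt for $\hat{\sigma}_t^2$:

```python
def posterior_params(x0, x_t, s, t):
    """Cumulants of the reverse transition q(x_s | x_t, x_0)     (Eq. 8.38).
    The mean is a convex combination of information from x_t and x_0."""
    a_ts, var_ts = transition_params(s, t)
    var_post = var_ts * sigma[s] ** 2 / sigma[t] ** 2
    mean_post = (a_ts * sigma[s] ** 2 / sigma[t] ** 2) * x_t \
              + (alpha[s] * var_ts / sigma[t] ** 2) * x0
    return mean_post, var_post
```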
The reverse-transition cross entropies, revisited.
We noted above that one way to evaluate the joint relative entropy for the diffusion model is to form Monte Carlo estimates of each of the summands in Eq. 8.34. Naïvely, we could evaluate each summand with samples from three random variables ($X_0, X_{t-1}, X_t$), but as we also noted, one of the expectations can actually be taken in closed form. In particular, if we expand according to the chain rule of probability given by Eq. 8.35—namely, in terms of the two distributions just derived, Eqs. 8.36 and 8.38 (letting $s = t-1$)—then we can write

$$\mathcal{L}(\theta) = \sum_{t=1}^{T} \mathbb{E}_{X_{t-1}, X_t}\!\left[\frac{1}{2\hat{\sigma}_t^2}\big\| X_{t-1} - \hat{\mu}_t(X_t) \big\|^2\right] + c = \sum_{t=1}^{T} \mathbb{E}_{X_0, X_t}\!\left[\frac{1}{2\hat{\sigma}_t^2}\big\| \mu_{t-1|t}(X_t, X_0) - \hat{\mu}_t(X_t) \big\|^2\right] + c. \qquad (8.39)$$

In moving to the second line, the expectation of the quadratic form was taken under $q(x_{t-1} \mid x_t, x_0)$ using identity B.13 from the appendix, except that the trace term, again a function of the fixed variance, was absorbed into the (now different) constant $c$. The remaining expectations can be approximated with Monte Carlo estimates, since we have in hand samples from the data distribution, $x_0$, and it is straightforward to generate samples of $X_t$ from Eq. 8.36.
How are we to interpret Eq. 8.39? It would be convenient if this could also be expressed as the mean squared error between an uncorrupted sample and a predictor—call it $\hat{x}_t$—that has access only to corrupted samples ($x_t$), as in Eq. 8.34. Of course it can, if we simply reparameterize the generative mean function on analogy with the mean of $q(x_{t-1} \mid x_t, x_0)$ (Eq. 8.38):

$$\hat{\mu}_t(x_t) =: \frac{\alpha_{t|t-1}\, \sigma_{t-1}^2}{\sigma_t^2}\, x_t + \frac{\alpha_{t-1}\, \sigma_{t|t-1}^2}{\sigma_t^2}\, \hat{x}_t(x_t).$$

Note that this reparameterization loses no generality. In terms of this predictor, the joint relative entropy then becomes

$$\mathcal{L}(\theta) = \sum_{t=1}^{T} \frac{\alpha_{t-1}^2\, \sigma_{t|t-1}^4}{2\hat{\sigma}_t^2\, \sigma_t^4}\, \mathbb{E}_{X_0, X_t}\!\left[\big\| X_0 - \hat{x}_t(X_t) \big\|^2\right] + c. \qquad (8.40)$$
So as in Eq. 8.34, minimizing the loss amounts to optimizing a denoising function. But in this case, it is the completely uncorrupted data samples, $x_0$, that are to be recovered, and accordingly a different (but related) denoising function/neural network ($\hat{x}_t$) that is to be used. The integrals in Eq. 8.40 are again to be estimated with samples, but from only two rather than three random variables.
Up till now we have refrained from specifying $\hat{\sigma}_t^2$. However, we now note that the expectations carried out in Eq. 8.39 amount to computing the cross entropy between $q(x_{t-1} \mid x_t, x_0)$ and $\hat{p}(x_{t-1} \mid x_t)$. Since cross entropy is minimized when the distributions are equal, it seems sensible simply to equate their variances. (Nevertheless, this is not quite optimal: we will not in general be able to set the means of these distributions precisely equal, so the variance really ought to soak up the difference.) Comparing Eqs. 8.38 and 8.33, we have

$$\hat{\sigma}_t^2 = \sigma_{t-1|t}^2 = \frac{\sigma_{t|t-1}^2\, \sigma_{t-1}^2}{\sigma_t^2},$$

in which case Eq. 8.40 simplifies to the even more elegant

$$\mathcal{L}(\theta) = \sum_{t=1}^{T} \frac{\mathrm{SNR}(t-1) - \mathrm{SNR}(t)}{2}\, \mathbb{E}_{X_0, X_t}\!\left[\big\| X_0 - \hat{x}_t(X_t) \big\|^2\right] + c. \qquad (8.41)$$
In words, each summand in Eq. 8.41 computes the mean squared error between the uncorrupted data and a denoised version of the corrupted data, $\hat{x}_t(X_t)$. But before summing, the MSE at step $t$ is weighted by the amount of SNR lost in transitioning from step $t-1$ to step $t$. In fact, Eq. 8.41 tells us that fitting a Gaussian reverse-diffusion model is equivalent to fitting a (conceptually) different generative model: in effect, a set of per-step denoising models,

$$\hat{p}(x_0 \mid x_t) = \mathcal{N}\!\Big(\hat{x}_t(x_t),\ \big(\mathrm{SNR}(t-1) - \mathrm{SNR}(t)\big)^{-1} I\Big),$$

each of which attempts to generate the uncorrupted data directly from the corrupted sample.
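In code, Eq. 8.41 is a remarkably simple training objective. The sketch below continues the one above (`x_hat` stands in for the denoising network $\hat{x}_t$): it draws one random step per datum and weights the squared error by the SNR lost at that step, with the factor of $T$ correcting for sampling the step uniformly rather than summing over all steps:

```python
def jre_loss(x0_batch, x_hat):
    """Monte Carlo estimate of Eq. 8.41 (up to the additive constant c)."""
    snr = np.exp(log_snr)                  # SNR(t) = alpha_t^2 / sigma_t^2
    loss = 0.0
    for x0 in x0_batch:
        t = rng.integers(1, T + 1)         # t ~ Uniform{1, ..., T}
        x_t = marginal_sample(x0, t)
        w = 0.5 * (snr[t - 1] - snr[t])    # SNR lost from step t-1 to t
        loss += T * w * np.sum((x0 - x_hat(x_t, t)) ** 2)
    return loss / len(x0_batch)

# Example with a crude "denoiser" that merely rescales the corrupted sample:
print(jre_loss(rng.standard_normal((8, 2)), lambda x_t, t: x_t / alpha[t]))
```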
Implementation
The continuous-time limit.
Eq. 8.36 tells us that the data can be corrupted to an arbitrary position in the Markov chain with a single computation. Consequently, it is not actually necessary to run the chain sequentially from 1 to $T$ during training, which is critical for parallelized implementations. Indeed, we need not even limit ourselves to an integer number of steps. Suppose we allow the SNR to be a monotonically decreasing function of a continuous variable $t$ that ranges from 0 to 1, such that $\mathrm{SNR}(k/T) = \alpha_k^2/\sigma_k^2$. For consistency, we will define another function on $[0, 1]$ for the marginal variance, such that $\sigma^2(k/T) = \sigma_k^2$, although $\sigma^2(t)$ need not be monotonic. Then if we express the SNR differences in Eq. 8.41 in terms of the “step size” $1/T$ and take the limit as $T \to \infty$, the loss becomes

$$\mathcal{L}(\theta) = -\frac{1}{2} \int_0^1 \mathrm{SNR}'(t)\, \mathbb{E}_{X_0, X_t}\!\left[\big\| X_0 - \hat{x}(X_t, t) \big\|^2\right] \mathrm{d}t + c = -\frac{1}{2}\, \mathbb{E}_{T'}\!\left[\mathrm{SNR}'(T')\, \mathbb{E}_{X_0, X_{T'}}\!\big[\| X_0 - \hat{x}(X_{T'}, T') \|^2\big]\right] + c, \qquad (8.42)$$

with $T'$ a random variable distributed uniformly on $[0, 1]$. This is the preferred implementation of diffusion models [25]. But the first line also suggests the change of variables $u = \mathrm{SNR}(t)$, under which the integral becomes

$$\mathcal{L}(\theta) = \frac{1}{2} \int_{\mathrm{SNR}(1)}^{\mathrm{SNR}(0)} \mathbb{E}_{X_0, X_u}\!\left[\big\| X_0 - \tilde{x}(X_u, u) \big\|^2\right] \mathrm{d}u + c, \qquad \tilde{x}(x_u, u) := \hat{x}\big(x_t, \mathrm{SNR}^{-1}(u)\big). \qquad (8.43)$$
Notice that $\mathrm{SNR}(t)$ can be safely removed from this equation, since we have not yet committed to any particular $\hat{x}$, and it is assumed to be arbitrarily flexible; so we may as well treat the denoiser, $\tilde{x}$, as a function of the noise level $u$ directly. This shows that in the continuous-time limit, any choice of SNR function yields the same joint relative entropy in expectation, as long as (1) it is monotonically decreasing and (2) it has well-chosen endpoints, $\mathrm{SNR}(0)$ and $\mathrm{SNR}(1)$. However, this choice does affect the variance of this sample average, and various SNR “schedules” have been experimented with in practice [25].
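A continuous-time training step then looks as follows. This sketch (reusing `rng` from above) assumes a variance-preserving process parameterized directly by the log-SNR, so the schedule enters only through `log_snr_fn` and its derivative; both function arguments, like the stand-in network, are assumptions for illustration:

```python
def continuous_loss(x0_batch, x_hat, log_snr_fn, dlog_snr_dt):
    """Monte Carlo estimate of Eq. 8.42: t ~ Uniform(0,1), weight -SNR'(t)."""
    loss = 0.0
    for x0 in x0_batch:
        t = rng.uniform()
        lsnr = log_snr_fn(t)
        a = np.sqrt(1.0 / (1.0 + np.exp(-lsnr)))          # alpha(t)
        s2 = 1.0 - a ** 2                                 # sigma^2(t)
        x_t = a * x0 + np.sqrt(s2) * rng.standard_normal(x0.shape)
        dsnr_dt = np.exp(lsnr) * dlog_snr_dt(t)           # SNR' = SNR * (log SNR)'
        loss += -0.5 * dsnr_dt * np.sum((x0 - x_hat(x_t, t)) ** 2)
    return loss / len(x0_batch)

print(continuous_loss(rng.standard_normal((8, 2)),
                      lambda x_t, t: x_t,                 # stand-in network
                      log_snr_fn=lambda t: 10.0 - 20.0 * t,
                      dlog_snr_dt=lambda t: -20.0))
```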
A connection to denoising score matching
There is in fact another illuminating reparameterization, this time in terms of the negative energy gradient††margin: negative energy gradient , $f(x) := -\nabla_x E(x)$. Because this quantity can be written (generically) as $\nabla_x \log p(x)$, it is sometimes called the score function, for its resemblance to the score of classical statistics, $\nabla_\theta \log p(x; \theta)$. We will call it the force††margin: force . Intuitively, the force points toward the modal $x$. Furthermore, if our goal in fitting a generative model is merely to synthesize new data, then the force suffices, because the iteration
$$x_{k+1} = x_k + \epsilon\, f(x_k) + \sqrt{2\epsilon}\, z_k \qquad (8.44)$$

(with $z_k \sim \mathcal{N}(0, I)$, and step size $\epsilon$) can be shown to generate samples approximately from the distribution $p(x)$. This iteration is known as Langevin dynamics††margin: Langevin dynamics , and we return to it in Section 10.2.1. For now we simply ask how we might get or estimate the force.
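Eq. 8.44 is a one-line iteration. Here is a minimal, self-contained sketch; the force of a standard normal, $f(x) = -x$, is used as an example so that the output can be checked against the known target:

```python
import numpy as np

rng = np.random.default_rng(1)

def langevin_sample(force, x_init, eps=1e-2, n_steps=1000):
    """Langevin dynamics (Eq. 8.44). For small step size and many steps,
    the iterate is distributed approximately according to the density
    whose force (negative energy gradient) is `force`."""
    x = np.array(x_init, dtype=float)
    for _ in range(n_steps):
        x += eps * force(x) + np.sqrt(2 * eps) * rng.standard_normal(x.shape)
    return x

# Example: the standard normal has energy ||x||^2 / 2, and hence force -x.
print(langevin_sample(lambda x: -x, np.zeros(2)))
```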
Consider a random variable $X_t$ that was created by corrupting the random variable of interest, $X_0$, with some kind of additive noise. The resulting marginal distribution of $X_t$,

$$q(x_t) = \int q(x_t \mid x_0)\, q(x_0)\, \mathrm{d}x_0,$$

can be thought of as a kernel-density estimate of the distribution of interest, $q(x_0)$, with data samples $x_0$ and kernel $q(x_t \mid x_0)$. So perhaps we can use the former in the place of the latter in our Langevin dynamics, Eq. 8.44. But then how are we to get the force of $q(x_t)$? Expanding it, we find that

$$\nabla_{x_t} \log q(x_t) = \frac{\int \nabla_{x_t}\, q(x_t \mid x_0)\, q(x_0)\, \mathrm{d}x_0}{q(x_t)} = \mathbb{E}_{X_0 \mid x_t}\!\big[\nabla_{x_t} \log q(x_t \mid X_0)\big]. \qquad (8.45)$$
This equation says that the force of the marginal equals the expected (under the posterior, $q(x_0 \mid x_t)$) force of the conditional, $q(x_t \mid x_0)$. The latter can often be computed. For example, if the data are corrupted by scaling and then adding zero-mean Gaussian noise, as in Eq. 8.36, then the conditional energy and its expected negative gradient (force) are

$$E(x_t \mid x_0) = \frac{\| x_t - \alpha_t x_0 \|^2}{2\sigma_t^2} \quad\implies\quad \mathbb{E}_{X_0 \mid x_t}\!\big[{-\nabla_{x_t} E(x_t \mid X_0)}\big] = \frac{\alpha_t\, \mathbb{E}[X_0 \mid x_t] - x_t}{\sigma_t^2}.$$
Putting this together with Eq. 8.45, we see that for additive Gaussian noise,

$$\nabla_{x_t} \log q(x_t) = \frac{\alpha_t\, \mathbb{E}[X_0 \mid x_t] - x_t}{\sigma_t^2}. \qquad (8.46)$$

This is sometimes called Tweedie’s formula††margin: Tweedie’s formula . This looks helpful—if we had the posterior mean in hand!
Now it is a fact that, of all estimators for $X_0$, the posterior mean $\mathbb{E}[X_0 \mid x_t]$ has minimum mean squared error [42]. Putting all these pieces together [18, 52, 23] yields the following procedure to generate samples from the distribution of interest, $q(x_0)$: (1) find an estimator for $X_0$ that minimizes mean squared error; (2) use this in place of the posterior mean in Eq. 8.46 to compute the expected conditional force and, consequently, the marginal force; (3) use the marginal force, $\nabla_{x_t} \log q(x_t)$, as a proxy for the data force, $\nabla_{x_0} \log q(x_0)$, in Langevin dynamics (Eq. 8.44). This method of density estimation is known as denoising score matching††margin: denoising score matching .
Now we examine the diffusion model in light of this procedure. Fitting the generative model to the data has turned out to be equivalent to minimizing the mean squared error between $X_0$ and $\hat{x}_t(X_t)$ (Eq. 8.41). Therefore we can interpret $\hat{x}_t(x_t)$ as (an estimator for) the posterior mean, $\mathbb{E}[X_0 \mid x_t]$. The samples $x_t$ are generated by a recognition model that corrupts the data samples with Gaussian noise (Eq. 8.36). Therefore we can use the posterior-mean estimator and Tweedie’s formula (Eq. 8.46) to construct a force estimator:

$$\hat{f}_t(x_t) := \frac{\alpha_t\, \hat{x}_t(x_t) - x_t}{\sigma_t^2}. \qquad (8.47)$$
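Eq. 8.47 converts whatever denoiser we have trained into a force estimator, which can then be plugged into the Langevin sampler sketched above. Continuing the earlier sketch (with `x_hat` again a stand-in for the trained denoiser):

```python
def force_estimate(x_t, t, x_hat):
    """Tweedie-based force estimator (Eq. 8.47), built from a denoiser
    x_hat that approximates the posterior mean E[X_0 | x_t]."""
    return (alpha[t] * x_hat(x_t, t) - x_t) / sigma[t] ** 2

# E.g., run Langevin dynamics at a fixed, mild noise level t = 10:
# langevin_sample(lambda x: force_estimate(x, 10, x_hat), np.zeros(2))
```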
The force estimator also provides a good proxy for the data force, $\nabla_{x_0} \log q(x_0)$, and consequently can be used to generate data with Langevin dynamics (Eq. 8.44).
Alternatively, the force can be fit directly, rather than indirectly via the posterior mean. That is easily done here as well, simply by rearranging Eq. 8.47 to reparameterize $\hat{x}_t$ (again without loss of generality):

$$\hat{x}_t(x_t) =: \frac{x_t + \sigma_t^2\, \hat{f}_t(x_t)}{\alpha_t},$$

for some arbitrary function (neural network) $\hat{f}_t$. Under this reparameterization, Eq. 8.41 becomes

$$\begin{aligned}
\mathcal{L}(\theta) &= \sum_{t=1}^{T} \frac{\mathrm{SNR}(t-1) - \mathrm{SNR}(t)}{2}\, \mathbb{E}_{X_0, X_t}\!\left[\left\| X_0 - \frac{X_t + \sigma_t^2\, \hat{f}_t(X_t)}{\alpha_t} \right\|^2\right] + c \\
&= \sum_{t=1}^{T} \frac{\mathrm{SNR}(t-1) - \mathrm{SNR}(t)}{2\alpha_t^2}\, \mathbb{E}_{X_0, X_t}\!\left[\big\| (\alpha_t X_0 - X_t) - \sigma_t^2\, \hat{f}_t(X_t) \big\|^2\right] + c \\
&= \sum_{t=1}^{T} \frac{\mathrm{SNR}(t-1) - \mathrm{SNR}(t)}{2}\, \frac{\sigma_t^2}{\alpha_t^2}\, \mathbb{E}_{X_0, Z_t}\!\left[\big\| \sigma_t\, \hat{f}_t(X_t) + Z_t \big\|^2\right] + c, \qquad X_t = \alpha_t X_0 + \sigma_t Z_t. \qquad (8.48)
\end{aligned}$$
The last step follows from the reparameterization $X_t = \alpha_t X_0 + \sigma_t Z_t$, with $Z_t$ standard normal. Intuitively, $\hat{f}_t$ learns, like $\hat{x}_t$, how to uncorrupt data. But rather than transforming the corrupted sample ($x_t$) into an estimate of the (scaled) uncorrupted sample itself ($\alpha_t x_0$), a good $\hat{f}_t$ produces (second line of Eq. 8.48) an estimate of the vector that points back to $\alpha_t x_0$ from the corrupted sample $x_t$. This is consistent with our conclusion that any $\hat{f}_t$ that satisfies Eq. 8.47 provides an estimator for the force of the data distribution. Alternatively, the final line of Eq. 8.48 tells us that the force estimator must try to recover each realization of noise ($z_t$) that corrupted each observed datum ($x_0$). But notice that the negative force, i.e. the positive energy gradient, must point in the direction of $z_t$. This makes sense: we expect the noise to be “uphill.”
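As a sampled objective, the final line of Eq. 8.48 says: corrupt the datum with known noise, then penalize the (scaled) force network for failing to point back along that noise. A sketch, continuing the conventions above (`f_hat` stands in for the force network $\hat{f}_t$):

```python
def score_matching_loss(x0_batch, f_hat):
    """Final line of Eq. 8.48 as a sampled objective: the scaled negative
    force, -sigma_t * f_hat(x_t, t), must recover the noise z that was
    added to each datum."""
    snr = np.exp(log_snr)
    loss = 0.0
    for x0 in x0_batch:
        t = rng.integers(1, T + 1)
        z = rng.standard_normal(x0.shape)
        x_t = alpha[t] * x0 + sigma[t] * z                 # corrupt with known noise
        w = 0.5 * (snr[t - 1] - snr[t]) * sigma[t] ** 2 / alpha[t] ** 2
        loss += T * w * np.sum((sigma[t] * f_hat(x_t, t) + z) ** 2)
    return loss / len(x0_batch)
```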
In either case, each summand in Eq. 8.48 corresponds to an objective for denoising score matching; or, put the other way around, fitting a Gaussian (reverse-)diffusion model (Eqs. 8.30, 8.31, 8.37, and 8.33) amounts to running denoising score matching for many different kernel widths. And indeed, such a learning procedure has been proposed and justified independently under this description [48].