
Modular Diffusion


Custom Modules

When tinkering with Diffusion Models, the time will come when you need to venture beyond what the base library offers and modify the diffusion process to fit your needs. Modular Diffusion meets this requirement by providing an abstract base class for each module type, which can be extended to define custom behavior. In this tutorial, we provide an overview of each base class and an example of how to extend it.

Type annotations

As with all library code, this tutorial adheres to strict type checking standards. We recommend typing your code, but you may elect to skip the annotations. If you do, however, you will not be warned when you try to mix incompatible modules, and you will miss out on other useful IntelliSense.

Data transform

In many Diffusion Model applications, the diffusion process takes place in the dataset space. If this is your case, the prebuilt Identity data transform module will serve your purposes, leaving your data untouched before applying noise during training. However, a growing number of algorithms, like Stable Diffusion and Diffusion-LM, project data onto a latent space before applying diffusion.

In the case of Diffusion-LM, the dataset consists of sequences of word IDs, but the diffusion process happens in the word embedding space. This means you need a way to convert sequences of word IDs into sequences of embeddings and to train the embeddings along with the Diffusion Model. In Modular Diffusion, this can be achieved by extending the Data base class and implementing its encode and decode methods. The former projects the data into the latent space and the latter maps it back to the dataset space. Let’s take a look at how you could implement the aforementioned transform:

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from diffusion.base import Data

@dataclass
class Embedding(Data):
    count: int = 2
    dimension: int = 256

    def __post_init__(self) -> None:
        # trainable embedding table shared by encode and decode
        self.embedding = nn.Embedding(self.count, self.dimension)

    def encode(self, w: Tensor) -> Tensor:
        # map word IDs to their embedding vectors
        return self.embedding(w)

    def decode(self, x: Tensor) -> Tensor:
        # map each vector to the ID of its nearest embedding
        return torch.cdist(x, self.embedding.weight).argmin(-1)

In the encode method, we transform the input tensor w into an embedding tensor using the learned embedding layer. The decode method reverses this operation by finding, for each vector in x, the closest embedding in the embedding weight matrix.
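To see the transform in action, here is a hypothetical usage sketch. The Data base class takes additional constructor arguments, such as the dataset itself, which we pass positionally here by analogy with the library's other Data modules; check the API Reference for the exact signature.

import torch

w = torch.randint(0, 1000, (512, 64))           # dataset: 512 sequences of 64 word IDs
data = Embedding(w, count=1000, dimension=256)  # base class arguments are assumed

x = data.encode(w[:8])   # [8, 64, 256] tensor of word embeddings
w_hat = data.decode(x)   # [8, 64] tensor of recovered word IDs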

Data transforms can also be useful when they have no trainable parameters. For example, the Categorical noise module operates over one-hot vectors, which are very memory-inefficient. To mitigate this, you may store your data as a list of labels and use the OneHot data transform module to convert it into one-hot vectors on a batch-by-batch basis, saving a lot of memory. Your data transform can also be a frozen variational autoencoder, as in Stable Diffusion. For further details, check out our Text Generation and Image Generation tutorials.
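As a sketch of the label-storage pattern described above (the module path and the k parameter are assumptions; check the API Reference for the exact signature):

import torch

from diffusion.data import OneHot  # assumed module path

labels = torch.randint(0, 10, (60000,))  # dataset stored as integer labels
data = OneHot(labels, k=10)              # expanded to one-hot vectors per batch

x = data.encode(labels[:32])   # [32, 10] one-hot vectors
recovered = data.decode(x)     # [32] integer labels again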

Noise schedule

You can implement your own custom diffusion schedule by extending the abstract Schedule base class and implementing its only abstract method, compute. This method is responsible for providing a tensor containing the values of $\alpha_t$ for $t \in \{0,\dots,T\}$. As an example, let’s implement the Linear schedule, which is already included in the library:

from dataclasses import dataclass

import torch
from torch import Tensor

from diffusion.base import Schedule

@dataclass
class Linear(Schedule):
    start: float
    end: float

    def compute(self) -> Tensor:
        # evenly spaced alpha values, one per time step t = 0, ..., T
        return torch.linspace(self.start, self.end, self.steps + 1)

Given that steps is already a parameter in the base class, all we need to do is define the start and end parameters and use them to compute the $\alpha_t$ values. Then, we can initialize the schedule with the syntax Linear(steps, start, end).
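The same recipe works for any closed-form schedule. As an illustration, here is a sketch of a cosine schedule in the spirit of Nichol & Dhariwal (2021); it is not meant as a faithful reproduction of their schedule, and the convention of setting $\alpha_0 = 1$ is our assumption:

import math
from dataclasses import dataclass

import torch
from torch import Tensor

from diffusion.base import Schedule

@dataclass
class Cosine(Schedule):
    offset: float = 8e-3

    def compute(self) -> Tensor:
        # f(t) = cos(((t / T + s) / (1 + s)) * pi / 2)^2, alpha_t = f(t) / f(t - 1)
        t = torch.arange(self.steps + 1)
        f = torch.cos((t / self.steps + self.offset) / (1 + self.offset) * math.pi / 2) ** 2
        alpha = torch.ones(self.steps + 1)
        alpha[1:] = f[1:] / f[:-1]
        return alpha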

Probability distribution

In the diffusion process, the chosen probability distribution plays a crucial role in modeling the noise that guides the transition between different states. The library comes prepackaged with a growing set of commonly used distributions, such as the Normal distribution, but different applications or experimental setups might require you to implement your own.

To define a custom distribution, you’ll need to extend the Distribution base class and implement three key methods:

  • sample, which draws a sample from the distribution and returns a tuple containing the sampled value and the applied noise (or None if not applicable);
  • nll, which computes the negative log-likelihood of a given tensor x;
  • dkl, which computes the Kullback-Leibler divergence between the distribution and another provided as other.

Take, for example, the Normal distribution, included in the library:

from dataclasses import dataclass
from typing import Self

import torch
from torch import Tensor

from diffusion.base import Distribution

@dataclass
class Normal(Distribution):
    mu: Tensor
    sigma: Tensor

    def sample(self) -> tuple[Tensor, Tensor]:
        # reparametrization: x = mu + sigma * epsilon, with epsilon ~ N(0, I)
        epsilon = torch.randn(self.mu.shape, device=self.mu.device)
        return self.mu + self.sigma * epsilon, epsilon

    def nll(self, x: Tensor) -> Tensor:
        # elementwise negative log-likelihood; 2.5066... = sqrt(2 * pi)
        return (0.5 * ((x - self.mu) / self.sigma)**2 +
                (self.sigma * 2.5066282746310002).log())

    def dkl(self, other: Self) -> Tensor:
        # closed-form KL divergence between two normal distributions
        return (torch.log(other.sigma / self.sigma) +
                (self.sigma**2 + (self.mu - other.mu)**2) / (2 * other.sigma**2) - 0.5)

Parameter shapes

The distribution parameters are represented as tensors with the same size as a batch. This essentially means that a Distribution object functions as a collection of distributions, where each individual element in a batch corresponds to a unique distribution. For instance, each pixel in a batch of images is associated with its own mu and sigma values.
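For instance, here is how the Normal distribution defined above behaves with batch-shaped parameters (shapes are illustrative):

import torch

# one distribution per pixel in a batch of 8 RGB 32x32 images
mu = torch.zeros(8, 3, 32, 32)
sigma = torch.ones(8, 3, 32, 32)
q = Normal(mu, sigma)

x, epsilon = q.sample()   # sampled values and applied noise, both [8, 3, 32, 32]
nll = q.nll(x)            # elementwise negative log-likelihood, [8, 3, 32, 32]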

Noise type

In most Diffusion Model applications, the standard choice of noise is Gaussian, which is already bundled within the library. However, there may be scenarios where you want to experiment with variations of standard Gaussian noise, as in DDIM, introduced in Song et al. (2020), or venture into entirely different noise types, like the one used in D3PM, introduced in Austin et al. (2021). To create your own noise behavior, you will need to extend the abstract Noise base class and implement each of the following methods:

  • schedule(self, alpha: Tensor) -> None: This method is intended for precomputing resources based on the noise schedule $\alpha_t$ for $t \in \{0,\dots,T\}$. This can be beneficial for performance reasons when some calculations can be done ahead of time. A common use is calculating $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$.
  • stationary(self, shape: tuple[int, ...]) -> Distribution: This method is tasked with computing the stationary distribution $q(x_T)$, i.e., the noise distribution at the final time step, given a target shape.
  • prior(self, x: Tensor, t: Tensor) -> Distribution: This method computes the prior distribution $q(x_t | x_0)$, i.e., the distribution of the noisy images $x_t$ or z given the initial image $x_0$ or x.
  • posterior(self, x: Tensor, z: Tensor, t: Tensor) -> Distribution: This method computes the posterior distribution $q(x_{t-1} | x_t, x_0)$, i.e., the distribution of the less noisy images $x_{t-1}$ given the current noisy image $x_t$ or z and the initial image $x_0$ or x.
  • approximate(self, z: Tensor, t: Tensor, hat: Tensor) -> Distribution: This method computes the approximate posterior distribution $p_\theta(x_{t-1} | x_t)$, i.e., the distribution of the less noisy images $x_{t-1}$ given the current noisy image $x_t$ or z. This is an approximation to the true posterior distribution that is easier to sample from or compute. The tensor hat is the output of the denoiser network containing the predicted parameters; it is named this way because predicted values are often denoted with a hat, e.g., $\hat{\epsilon}$.

If you aim to replicate a specific research paper, you only need to translate its mathematical expressions into code. For example, the original DDPM paper gives the following equations:

  • $q(x_T) = \mathcal{N}(x_T; 0, \mathrm{I})$
  • $q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar\alpha_t}\, x_0, (1 - \bar\alpha_t)\mathrm{I})$
  • $q(x_{t-1} | x_t, x_0) = \mathcal{N}(x_{t-1}; \frac{\sqrt{\alpha_t}(1 - \bar\alpha_{t-1})x_t + \sqrt{\bar\alpha_{t-1}}(1 - \alpha_t)x_0}{1 - \bar\alpha_t}, \frac{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\mathrm{I})$
  • $p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \frac{1}{\sqrt{\alpha_t}}x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt{\alpha_t}}\hat\epsilon, \frac{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\mathrm{I})$

where $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$ is calculated beforehand for better performance. In Modular Diffusion, here’s how we could implement this type of Gaussian noise:

from dataclasses import dataclass

import torch
from torch import Tensor

from diffusion.base import Noise
from diffusion.distribution import Normal as N

@dataclass
class Gaussian(Noise[N]):
    def schedule(self, alpha: Tensor) -> None:
        self.alpha = alpha
        # delta holds the cumulative products, i.e., alpha bar
        self.delta = alpha.cumprod(0)

    def stationary(self, shape: tuple[int, ...]) -> N:
        return N(torch.zeros(shape), torch.ones(shape))

    def prior(self, x: Tensor, t: Tensor) -> N:
        t = t.view(-1, *(1,) * (x.dim() - 1))
        # q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
        return N(self.delta[t].sqrt() * x, (1 - self.delta[t]).sqrt())

    def posterior(self, x: Tensor, z: Tensor, t: Tensor) -> N:
        t = t.view(-1, *(1,) * (x.dim() - 1))
        mu = self.alpha[t].sqrt() * (1 - self.delta[t - 1]) * z
        mu += self.delta[t - 1].sqrt() * (1 - self.alpha[t]) * x
        mu /= (1 - self.delta[t])
        sigma = (1 - self.alpha[t]) * (1 - self.delta[t - 1]) / (1 - self.delta[t])
        sigma = sigma.sqrt()
        return N(mu, sigma)

    def approximate(self, z: Tensor, t: Tensor, hat: Tensor) -> N:
        t = t.view(-1, *(1,) * (z.dim() - 1))
        mu = (z - (1 - self.alpha[t]) / (1 - self.delta[t]).sqrt() * hat[0])
        mu /= self.alpha[t].sqrt()
        sigma = (1 - self.alpha[t]) * (1 - self.delta[t - 1]) / (1 - self.delta[t])
        sigma = sigma.sqrt()
        return N(mu, sigma)

Broadcasting

You will notice that some methods start with a statement that reshapes the tensor t. This is only done to allow broadcasting in the subsequent operations. For instance, in the prior method, we need to multiply self.delta[t].sqrt() by x, but self.delta has shape [t] and x has shape [b, c, h, w]. By reshaping t to [b, 1, 1, 1], we can perform the multiplication without any issues.
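Here is a standalone illustration of that reshape, outside of any module:

import torch

delta = torch.rand(1001)            # schedule values, shape [T + 1]
x = torch.randn(8, 3, 32, 32)       # batch of images, shape [b, c, h, w]
t = torch.randint(1, 1001, (8,))    # one time step per batch element, shape [b]

t = t.view(-1, *(1,) * (x.dim() - 1))   # shape [8, 1, 1, 1]
out = delta[t].sqrt() * x               # [8, 1, 1, 1] broadcasts against [8, 3, 32, 32]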

The schedule method precomputes the alpha and delta values (the cumulative products $\bar\alpha_t$), which are used in the other methods. The stationary method defines the distribution $q(x_T)$ from which sampling starts, while the prior, posterior, and approximate methods implement the corresponding equations listed above. Collectively, these methods define the complete Gaussian noise model from the original DDPM paper. Note that a more efficient solution is possible by precomputing some of the recurring expressions used in these methods, as sketched below.
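For instance, the square roots taken inside prior, posterior, and approximate could be cached once per schedule. The sketch below shows the idea applied to the schedule method above; the attribute names are our own:

    def schedule(self, alpha: Tensor) -> None:
        self.alpha = alpha
        self.delta = alpha.cumprod(0)
        # cache recurring expressions instead of recomputing them every batch
        self.sqrt_alpha = alpha.sqrt()
        self.sqrt_delta = self.delta.sqrt()
        self.sqrt_one_minus_delta = (1 - self.delta).sqrt()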

Denoiser neural network

Modular Diffusion comes with general-purpose UNet and Transformer classes, which have proven to be effective denoising networks in the context of Diffusion Models. However, it is not uncommon for authors to modify these networks to achieve even better results. To design your own network, extend the abstract Net base class. This class acts only as a thin wrapper over the standard PyTorch nn.Module class, meaning you can use it in exactly the same way. The forward method should take three tensor arguments: the noisy input x, the conditioning matrix y, and the diffusion time steps t.

Network output shape

When creating your neural network, it’s important to remember that the first dimension of its output will be interpreted as the parameter index, irrespective of the number of parameters being predicted. For instance, if your network is predicting both the mean and variance of noise in an image, the output shape should be [2, c, h, w]. But even if you’re predicting only the mean, the shape should be [1, c, h, w] — not [c, h, w].
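Putting the signature and the output shape together, here is a minimal sketch of a custom denoiser, assuming Net can be subclassed like a plain nn.Module. It is an MLP kept deliberately small for illustration, not a network you would use in practice:

import torch
from torch import Tensor, nn

from diffusion.base import Net

class MLP(Net):
    def __init__(self, features: int, steps: int) -> None:
        super().__init__()
        self.embedding = nn.Embedding(steps + 1, features)  # time step embedding
        self.layers = nn.Sequential(
            nn.Linear(features, features),
            nn.SiLU(),
            nn.Linear(features, features),
        )

    def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor:
        # the conditioning matrix y is ignored in this unconditional sketch
        h = self.layers(x + self.embedding(t))
        return h.unsqueeze(0)  # [1, b, features]: a single predicted parameter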

In scenarios where your network requires only a post-processing step, such as applying a Softmax function, there’s no need to create an entirely new network class. Modular Diffusion allows for a more concise approach using the pipe operator, as shown in the Getting Started tutorial:

from diffusion.net import Transformer
from torch.nn import Softmax

net = Transformer(input=512) | Softmax(3)

Loss function

In each training step, your Model instance creates a Batch object, which contains all the information you need about the current batch to compute the corresponding loss. To create a custom loss function, extend the Loss base class and implement the compute method, where the loss is calculated based on the current batch. Let’s start by implementing $L_\text{simple}$, introduced in Ho et al. (2020). The formula for this loss function is $\mathbb{E}\left[\lVert \epsilon - \hat{\epsilon}_\theta \rVert^2\right]$, where $\epsilon$ is the added noise and $\hat{\epsilon}_\theta$ is the predicted noise.

from torch import Tensor

from diffusion.base import Batch, Distribution, Loss

class Simple(Loss[Distribution]):
    def compute(self, batch: Batch[Distribution]) -> Tensor:
        # mean squared error between the added and the predicted noise
        return ((batch.epsilon - batch.hat[0])**2).mean()

Notice how we parametrize the Loss and Batch classes with the Distribution type. This just tells your IDE that this loss class works with any kind of distribution. If you’d like to make a loss function that is only compatible with, say, Normal distributions, you should specify this inside the square brackets. Another thing to note is that we assume that the first parameter in the denoiser network output hat is $\hat{\epsilon}_\theta$. You can alter this behavior by changing the index, or even make it parametrizable with a class property, as sketched below.
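A sketch of that parametrizable variant, with a hypothetical index field (this is not a library convention, just one way to do it):

from dataclasses import dataclass

@dataclass
class Indexed(Loss[Distribution]):
    index: int = 0  # hypothetical: position of the target parameter in hat

    def compute(self, batch: Batch[Distribution]) -> Tensor:
        return ((batch.epsilon - batch.hat[self.index])**2).mean()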

In certain scenarios, you might not want to compute your loss from batch.hat directly, but instead utilize the approximate posterior distribution $p_\theta(x_{t-1} | x_t)$, which is itself estimated from batch.hat in the Noise module. This is the case when you need to compute the variational lower bound (VLB), the original loss function used to train Diffusion Models. The formula for the VLB is expressed as:

$$\begin{aligned}L_\text{vlb} & = \mathbb{E}_{q(x_1|x_0)}\left[\log p_\theta(x_0|x_1)\right] \\ & - \sum_{t=2}^{T} \mathbb{E}_{q(x_t|x_0)}\left[D_{KL}(q(x_{t-1}|x_t, x_0)\,\|\,p_\theta(x_{t-1}|x_t))\right] \\ & - D_{KL}(q(x_T|x_0)\,\|\,p(x_T))\end{aligned}$$

Considering that the $D_{KL}(q(x_T|x_0)\,\|\,p(x_T))$ term is assumed to be 0 in the context of Diffusion Models, you can implement this function as follows:

class VLB(Loss[Distribution]):
    def compute(self, batch: Batch[Distribution]) -> Tensor:
        t = batch.t.view(-1, *(1,) * (batch.x.ndim - 1))
        # use the KL term for t > 1 and the reconstruction NLL term at t = 1
        return batch.q.dkl(batch.p).where(t > 1, batch.p.nll(batch.x)).mean()

Here, batch.p and batch.q represent $p_\theta(x_{t-1} | x_t)$ and $q(x_{t-1} | x_t, x_0)$, respectively. For a full list of Batch properties, check out the library’s API Reference.

On the other hand, if you wish to train your model using a hybrid loss function that is a linear combination of two or more existing functions, you can do so without creating a new Hybrid module. For instance, to combine the Simple and VLB loss functions, as proposed in Nichol & Dhariwal (2021), you can use the following syntax.

from diffusion.loss import Simple, VLB

loss = Simple(parameter="epsilon") + 0.001 * VLB()

Guidance

As of right now, ClassifierFree guidance is hardcoded into the diffusion process, and there is no way of extending the base Guidance class short of creating your own custom Model class. You can expect this behavior to change in an upcoming release. Please refer to our official Issue Tracker for updates.
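In the meantime, classifier-free guidance can still be enabled through the Model constructor. The snippet below reflects the syntax shown in the Getting Started tutorial, but the parameter names shown here may not match the current API exactly; check the API Reference for the exact signature:

from diffusion.guidance import ClassifierFree

# drop the condition 10% of the time during training, amplify it at sampling
guidance = ClassifierFree(dropout=0.1, strength=2)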


If you spot any typo or technical imprecision, please submit an issue or pull request to the library's GitHub repository.