Custom Modules
When tinkering with Diffusion Models, the time will come when you need to venture beyond what the base library offers and modify the diffusion process to fit your needs. Modular Diffusion meets this requirement by providing an abstract base class for each module type, which can be extended to define custom behavior. In this tutorial, we provide an overview of each base class and an example of how to extend it.
Type annotations
As with all library code, this tutorial adheres to strict type checking standards. Although we recommend typing your code, you may elect to skip writing type annotations. If you do, however, you will not receive warnings when you try to mix incompatible modules, nor other useful IntelliSense hints.
Data transform
In many Diffusion Model applications, the diffusion process takes place in the dataset space. If this is your case, the prebuilt Identity data transform module will serve your purposes, leaving your data untouched before applying noise during training. However, a growing number of algorithms, like Stable Diffusion and Diffusion-LM, project data onto a latent space before applying diffusion.
In the case of Diffusion-LM, the dataset consists of sequences of word IDs, but the diffusion process happens in the word embedding space. This means you need a way of converting sequences of word IDs into sequences of embeddings, and of training the embeddings along with the Diffusion Model. In Modular Diffusion, this can be achieved by extending the Data base class and implementing its encode and decode methods. The former projects the data into the latent space and the latter maps it back to the dataset space. Let's take a look at how you could implement the aforementioned transform:
from dataclasses import dataclass

import torch
from torch import Tensor, nn

from diffusion.base import Data

@dataclass
class Embedding(Data):
    count: int = 2
    dimension: int = 256

    def __post_init__(self) -> None:
        # Trainable embedding table shared by encode and decode.
        self.embedding = nn.Embedding(self.count, self.dimension)

    def encode(self, w: Tensor) -> Tensor:
        return self.embedding(w)

    def decode(self, x: Tensor) -> Tensor:
        # Map each vector back to the ID of its nearest embedding.
        return torch.cdist(x, self.embedding.weight).argmin(-1)
In the encode method, we transform the input tensor w into an embedding tensor using the learned embedding layer. The decode method reverses this operation by finding, for each vector in x, the most similar embedding in the embedding weight matrix.
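To make the behavior concrete, here is a quick, hypothetical usage example (the vocabulary size and shapes are illustrative):

data = Embedding(count=10000, dimension=64)
w = torch.randint(0, 10000, (8, 16))  # [batch, sequence] of word IDs
x = data.encode(w)                    # [8, 16, 64] tensor of embeddings
w = data.decode(x)                    # recovers the original IDs, since x is noiseless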
Data transforms can also be useful in cases where they have no trainable parameters. For example, the Categorical noise module operates over one-hot vectors, which are very memory-inefficient. To mitigate this, you may store your data as a list of labels and use the OneHot data transform module to convert it into one-hot vectors on a batch-by-batch basis, saving you a lot of memory. Your data transform can also be a frozen variational autoencoder, as in Stable Diffusion. For further details, check out our Text Generation and Image Generation tutorials.
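To illustrate the parameter-free case, here is a minimal sketch in the spirit of the OneHot module (the class below is hypothetical; the library ships its own implementation):

from dataclasses import dataclass

import torch.nn.functional as F
from torch import Tensor

from diffusion.base import Data

@dataclass
class OneHotSketch(Data):
    classes: int = 10

    def encode(self, w: Tensor) -> Tensor:
        # Expand integer labels into one-hot vectors batch by batch.
        return F.one_hot(w, self.classes).float()

    def decode(self, x: Tensor) -> Tensor:
        # Collapse (possibly noisy) vectors back to the most likely label.
        return x.argmax(-1)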
Noise schedule
You can implement your own custom diffusion schedule by extending the abstract Schedule base class and implementing its only abstract method, compute. This method is responsible for providing a tensor containing the values of $\alpha_t$ for $t \in \{0, 1, \dots, T\}$. As an example, let's implement the Linear schedule, which is already included in the library:
from dataclasses import dataclass

import torch
from torch import Tensor

from diffusion.base import Schedule

@dataclass
class Linear(Schedule):
    start: float
    end: float

    def compute(self) -> Tensor:
        # One alpha value per time step t = 0, ..., T.
        return torch.linspace(self.start, self.end, self.steps + 1)
Given that steps is already a parameter in the base class, all we need to do is define the start and end parameters and use them to compute the $\alpha_t$ values. Then, we can initialize the schedule with the syntax Linear(steps, start, end).
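For instance, a schedule over 1000 steps could be instantiated as follows (the exact values are illustrative):

schedule = Linear(1000, 0.9999, 0.98)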
Probability distribution
In the diffusion process, the chosen probability distribution plays a crucial role in modeling the noise that guides the transition between different states. The library comes prepackaged with a growing set of commonly used distributions, such as the Normal distribution, but different applications or experimental setups might require you to implement your own.
To define a custom distribution, you'll need to extend the Distribution base class and implement three key methods: sample, which draws a sample from the distribution and returns a tuple containing the sampled value and the applied noise (or None if not applicable); nll, which computes the negative log-likelihood of a given tensor x; and dkl, which computes the Kullback-Leibler divergence between the distribution and another provided as other. Take, for example, the Normal distribution, included in the library:
from dataclasses import dataclass
from typing import Self

import torch
from torch import Tensor

from diffusion.base import Distribution

@dataclass
class Normal(Distribution):
    mu: Tensor
    sigma: Tensor

    def sample(self) -> tuple[Tensor, Tensor]:
        # Reparametrization trick: x = mu + sigma * epsilon.
        epsilon = torch.randn(self.mu.shape, device=self.mu.device)
        return self.mu + self.sigma * epsilon, epsilon

    def nll(self, x: Tensor) -> Tensor:
        # 2.5066... = sqrt(2 * pi).
        return (0.5 * ((x - self.mu) / self.sigma)**2
                + (self.sigma * 2.5066282746310002).log())

    def dkl(self, other: Self) -> Tensor:
        return (torch.log(other.sigma / self.sigma)
                + (self.sigma**2 + (self.mu - other.mu)**2) / (2 * other.sigma**2)
                - 0.5)
Parameter shapes
The distribution parameters are represented as tensors with the same size as a batch. This essentially means that a Distribution object functions as a collection of distributions, where each individual element in a batch corresponds to a unique distribution. For instance, each pixel in a batch of images is associated with its own mu and sigma values.
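For example, using the Normal distribution defined above, a single object can hold one distribution per pixel of a batch of images (shapes are illustrative):

mu = torch.zeros(8, 3, 32, 32)           # one mean per pixel
sigma = torch.ones(8, 3, 32, 32)         # one standard deviation per pixel
z, epsilon = Normal(mu, sigma).sample()  # z and epsilon have shape [8, 3, 32, 32]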
Noise type
In most Diffusion Model applications, the standard choice of noise is Gaussian, which is already bundled within the library. However, there may be scenarios where you want to experiment with variations of standard Gaussian noise, as in DDIM, introduced in Song et al. (2020), or venture into entirely different noise types, like the one used in D3PM, introduced in Austin et al. (2021). To create your own unique noise behavior, you will need to extend the abstract Noise base class and implement each one of the following methods:
- schedule(self, alpha: Tensor) -> None: This method is intended for precomputing resources based on the noise schedule $\alpha_t$. This can be beneficial for performance reasons when some calculations can be done ahead of time. A common use is calculating $\delta_t = \prod_{s=1}^{t} \alpha_s$.
- stationary(self, shape: tuple[int, ...]) -> Distribution: This method is tasked with computing the stationary distribution $q(x_T)$, i.e., the noise distribution at the final time step, given a target shape.
- prior(self, x: Tensor, t: Tensor) -> Distribution: This method computes the prior distribution $q(x_t \mid x_0)$, i.e., the distribution of the noisy image $x_t$ or z given the initial image $x_0$ or x.
- posterior(self, x: Tensor, z: Tensor, t: Tensor) -> Distribution: This method computes the posterior distribution $q(x_{t-1} \mid x_t, x_0)$, i.e., the distribution of the less noisy image $x_{t-1}$ given the current noisy image $x_t$ or z and the initial image $x_0$ or x.
- approximate(self, z: Tensor, t: Tensor, hat: Tensor) -> Distribution: This method computes the approximate posterior distribution $p_\theta(x_{t-1} \mid x_t)$, i.e., the distribution of the less noisy image $x_{t-1}$ given the current noisy image $x_t$ or z. This is an approximation to the true posterior distribution that is easier to sample from or compute. The tensor hat is the output of the denoiser network containing the predicted parameters, named this way because predicted values are often denoted with a hat, e.g., $\hat{\epsilon}$.
If you aim to replicate a specific research paper, you only need to translate the mathematical expressions into code. For example, the original DDPM paper yields the following equations:

$$q(x_t \mid x_0) = \mathcal{N}\left(x_t;\ \sqrt{\delta_t}\, x_0,\ (1 - \delta_t) I\right)$$

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\left(x_{t-1};\ \frac{\sqrt{\alpha_t}(1 - \delta_{t-1})\, x_t + \sqrt{\delta_{t-1}}(1 - \alpha_t)\, x_0}{1 - \delta_t},\ \frac{(1 - \alpha_t)(1 - \delta_{t-1})}{1 - \delta_t} I\right)$$

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1};\ \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \delta_t}}\, \hat{\epsilon}\right),\ \frac{(1 - \alpha_t)(1 - \delta_{t-1})}{1 - \delta_t} I\right)$$

where $\delta_t = \prod_{s=1}^{t} \alpha_s$ is calculated beforehand for better performance. In Modular Diffusion, here's how we could implement this type of Gaussian noise:
from dataclasses import dataclass

import torch
from torch import Tensor

from diffusion.base import Noise
from diffusion.distribution import Normal as N

@dataclass
class Gaussian(Noise[N]):
    def schedule(self, alpha: Tensor) -> None:
        self.alpha = alpha
        self.delta = alpha.cumprod(0)

    def stationary(self, shape: tuple[int, ...]) -> N:
        return N(torch.zeros(shape), torch.ones(shape))

    def prior(self, x: Tensor, t: Tensor) -> N:
        t = t.view(-1, *(1,) * (x.dim() - 1))
        # q(x_t | x_0) uses the cumulative product delta_t.
        return N(self.delta[t].sqrt() * x, (1 - self.delta[t]).sqrt())

    def posterior(self, x: Tensor, z: Tensor, t: Tensor) -> N:
        t = t.view(-1, *(1,) * (x.dim() - 1))
        mu = self.alpha[t].sqrt() * (1 - self.delta[t - 1]) * z
        mu += self.delta[t - 1].sqrt() * (1 - self.alpha[t]) * x
        mu /= (1 - self.delta[t])
        sigma = (1 - self.alpha[t]) * (1 - self.delta[t - 1]) / (1 - self.delta[t])
        return N(mu, sigma.sqrt())

    def approximate(self, z: Tensor, t: Tensor, hat: Tensor) -> N:
        t = t.view(-1, *(1,) * (z.dim() - 1))
        mu = z - (1 - self.alpha[t]) / (1 - self.delta[t]).sqrt() * hat[0]
        mu /= self.alpha[t].sqrt()
        sigma = (1 - self.alpha[t]) * (1 - self.delta[t - 1]) / (1 - self.delta[t])
        return N(mu, sigma.sqrt())
Broadcasting
You will notice that some methods start with a statement that reshapes the tensor t. This is only done to allow broadcasting in the subsequent tensor operations. For instance, in the prior method, we need to multiply self.delta[t].sqrt() by x, but self.delta is a one-dimensional tensor, while x has shape [b, c, h, w]. By reshaping t to [b, 1, 1, 1], the indexed tensor self.delta[t] also has shape [b, 1, 1, 1] and can be multiplied by x without any issues.
The schedule method precomputes the alpha and delta (the cumulative product of alpha) values, which are used in the other methods. The stationary method defines the initial noise distribution, while the prior, posterior, and approximate methods implement the corresponding mathematical equations for the prior, posterior, and approximate posterior distributions. Collectively, these methods define the complete Gaussian noise model from the original DDPM paper. Note that it is possible to achieve a more efficient solution by precomputing some of the recurrent expressions used in the methods, as sketched below.
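One way to do so is to cache the recurring square roots once per schedule. The subclass below is only a sketch, and the attribute names are our own:

class FastGaussian(Gaussian):
    def schedule(self, alpha: Tensor) -> None:
        super().schedule(alpha)
        # Cache expressions that would otherwise be recomputed on every step.
        self.sqrt_alpha = self.alpha.sqrt()
        self.sqrt_delta = self.delta.sqrt()

The other methods can then reuse self.sqrt_alpha and self.sqrt_delta instead of calling sqrt on every training step.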
Denoiser neural network
Modular Diffusion comes with general-use UNet and Transformer classes, which have proven to be effective denoising networks in the context of Diffusion Models. However, it is not uncommon to see authors make modifications to these networks to achieve even better results. To design your own original network, extend the abstract Net base class. This class acts only as a thin wrapper over the standard PyTorch nn.Module class, meaning you can use it in exactly the same way. The forward method should take three tensor arguments: the noisy input x, the conditioning matrix y, and the diffusion time steps t.
Network output shape
When creating your neural network, it's important to remember that the first dimension of its output will be interpreted as the parameter index, irrespective of the number of parameters being predicted. For instance, if your network is predicting both the mean and variance of the noise in an image, the output shape should be [2, c, h, w]. But even if you're predicting only the mean, the shape should be [1, c, h, w], not [c, h, w].
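To make the signature and the output convention concrete, here is a minimal sketch of a custom network. The architecture is purely illustrative, assumes flat inputs of size dimension, and ignores the conditioning matrix for brevity:

from torch import Tensor, nn

from diffusion.base import Net

class TinyDenoiser(Net):
    def __init__(self, steps: int = 1000, dimension: int = 256) -> None:
        super().__init__()
        self.time = nn.Embedding(steps + 1, dimension)
        self.body = nn.Sequential(
            nn.Linear(dimension, dimension),
            nn.SiLU(),
            nn.Linear(dimension, dimension),
        )

    def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor:
        # Inject the time step and predict a single parameter (e.g., the noise).
        h = self.body(x + self.time(t))
        return h[None]  # [1, ...]: one leading index per predicted parameter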
In scenarios where your network requires only a post-processing step, such as applying a Softmax function, there's no need to create an entirely new network class. Modular Diffusion allows for a more concise approach using the pipe operator, as shown in the Getting Started tutorial:
from diffusion.net import Transformer
from torch.nn import Softmax
net = Transformer(input=512) | Softmax(3)
Loss function
In each training step, your Model instance creates a Batch object, which contains all the information you need about the current batch to compute the corresponding loss. To create a custom loss function, extend the Loss base class and implement the compute method, where the loss is calculated based on the current batch. Let's start by implementing the simple loss $L_{simple}$ introduced in Ho et al. (2020). The formula for this loss function is $L_{simple} = \mathbb{E}_{t, x_0, \epsilon}\left[\lVert \epsilon - \hat{\epsilon} \rVert^2\right]$, where $\epsilon$ is the noise added in the forward process and $\hat{\epsilon}$ is the predicted noise.
from torch import Tensor

from diffusion.base import Batch, Distribution, Loss

class Simple(Loss[Distribution]):
    def compute(self, batch: Batch[Distribution]) -> Tensor:
        # Mean squared error between the added and the predicted noise.
        return ((batch.epsilon - batch.hat[0])**2).mean()
Notice how we parametrize the Loss and Batch classes with the Distribution type. This just tells your IDE that this loss class can be used with any kind of distribution. If you'd like to make a loss function that is only compatible with, say, Normal distributions, you should specify this inside the square brackets. Another thing to note is how we assume that the first parameter in the denoiser neural network output hat (named this way because predictions are often denoted with a little hat) is $\hat{\epsilon}$. You can alter this behavior by changing the index, or even make it parametrizable with a class property.
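For instance, a loss that only works with Normal distributions could be declared like this (a hypothetical sketch meant only to show the parametrization):

from torch import Tensor

from diffusion.base import Batch, Loss
from diffusion.distribution import Normal

class NormalOnly(Loss[Normal]):
    def compute(self, batch: Batch[Normal]) -> Tensor:
        # batch.p is now known to be Normal, so Normal-specific
        # attributes like batch.p.mu are type-safe to access.
        return ((batch.x - batch.p.mu)**2).mean()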
In certain scenarios, you might not need to compute your loss from batch.hat directly, but instead utilize the approximate posterior distribution $p_\theta(x_{t-1} \mid x_t)$, which is itself estimated from batch.hat in the Noise module. This is the case when you need to compute the variational lower bound (VLB), the original loss function utilized to train Diffusion Models. The formula for the VLB is expressed as:

$$L_{vlb} = \mathbb{E}_q\Big[\underbrace{D_{KL}\left(q(x_T \mid x_0) \,\|\, p(x_T)\right)}_{L_T} + \sum_{t=2}^{T} \underbrace{D_{KL}\left(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\right)}_{L_{t-1}} \underbrace{- \log p_\theta(x_0 \mid x_1)}_{L_0}\Big]$$

Considering that the $L_T$ term is assumed to be 0 in the context of Diffusion Models, you can implement this function as follows:
class VLB(Loss[Distribution]):
    def compute(self, batch: Batch[Distribution]) -> Tensor:
        t = batch.t.view(-1, *(1,) * (batch.x.ndim - 1))
        # Use the KL term for t > 1 and the NLL term for t = 1.
        return batch.q.dkl(batch.p).where(t > 1, batch.p.nll(batch.x)).mean()
Here, batch.p and batch.q represent $p_\theta(x_{t-1} \mid x_t)$ and $q(x_{t-1} \mid x_t, x_0)$, respectively. For a full list of Batch properties, check out the library's API Reference.
On the other hand, if you wish to train your model using a hybrid loss function that is a linear combination of two or more existing functions, you can do so without creating a new Hybrid module. For instance, to combine the Simple and VLB loss functions, as proposed in Nichol & Dhariwal (2021), you can use the following syntax:
from diffusion.loss import Simple, VLB
loss = Simple(parameter="epsilon") + 0.001 * VLB()
Guidance
As of right now, ClassifierFree guidance is hardcoded into the diffusion process, and there is no way of extending the base Guidance class unless you create your own custom Model class. You can expect this behavior to change in an upcoming release. Please refer to our official Issue Tracker for updates.
If you spot any typo or technical imprecision, please submit an issue or pull request to the library's GitHub repository.