Modular Diffusion

Probability Distribution

In Diffusion Models, the choice of a probability distribution plays a pivotal role in modeling the noise that guides transitions between time steps. While the Distribution type is not directly used to parametrize the Model class, it is used to create custom Noise and Loss modules. Modular Diffusion provides you with a set of distribution classes you can use to create your own modules.

Parameter shapes

Distribution parameters are represented as tensors with the same size as a batch. This essentially means that a Distribution object functions as a collection of distributions, where each individual element in a batch corresponds to a unique distribution. For instance, in the case of a standard DDPM, each pixel in a batch of images is associated with its own mu and sigma values.
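As a plain-torch sketch of this idea (not the library's internals), the parameter tensors below share the shape of a tiny batch, so each element is drawn from its own normal distribution:

```python
import torch

# A batch of 2 "images" of 4 pixels each: every pixel gets its own
# mean and standard deviation, so the parameter tensors share the
# batch shape (2, 4).
mu = torch.zeros(2, 4)
sigma = torch.ones(2, 4)

# Sampling draws one value per element, i.e. from 8 independent
# normal distributions at once. The sample keeps the batch shape.
x = mu + sigma * torch.randn_like(mu)
print(x.shape)  # torch.Size([2, 4])
```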

Normal distribution

Continuous probability distribution that is ubiquitously used in Diffusion Models. It has the following density function:

f(x) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right)

Sampling from a normal distribution is denoted x \sim \mathcal{N}(\mu, \sigma^2) and is equivalent to sampling from a standard normal distribution (\mu = 0 and \sigma = 1), scaling the result by \sigma, and shifting it by \mu:

  • \epsilon \sim \mathcal{N}(0, \text{I})
  • x = \mu + \sigma \epsilon
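These two steps can be sketched directly in plain PyTorch (independently of the library's Normal class):

```python
import torch

torch.manual_seed(0)
mu = torch.tensor([1.0, -1.0])
sigma = torch.tensor([0.5, 2.0])

# Reparameterization: draw standard normal noise, then scale and shift.
epsilon = torch.randn_like(mu)  # epsilon ~ N(0, I)
x = mu + sigma * epsilon        # x ~ N(mu, sigma^2)
```

This is the reparameterization trick used throughout diffusion models: it separates the randomness (epsilon) from the parameters (mu, sigma), which keeps sampling differentiable with respect to the parameters.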

Parameters

  • mu: Tensor -> Mean tensor \mu.
  • sigma: Tensor -> Standard deviation tensor \sigma. Must have the same shape as mu.

Parametrization

Please note that the sigma parameter does not correspond to the variance \sigma^2, but to the standard deviation \sigma.

Example

import torch
from diffusion.distribution import Normal as N

distribution = N(torch.zeros(3), torch.full((3,), 2))
x, epsilon = distribution.sample()
# x = tensor([ 1.1053,  1.9027, -0.2554])
# epsilon = tensor([ 0.5527,  0.9514, -0.1277])

Categorical distribution

Discrete probability distribution that separately specifies the probability of each of k possible categories in a vector p. Sampling from a categorical distribution is denoted x \sim \text{Cat}(p).

Parameters

  • p: Tensor -> Probability tensor p. All elements must be non-negative and sum to 1 in the last dimension.

Example

import torch
from diffusion.distribution import Categorical as Cat

distribution = Cat(torch.tensor([[.1, .3, .6], [0, 0, 1]]))
x, _ = distribution.sample()
# x = tensor([[0., 1., 0.], [0., 0., 1.]])
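Under the hood, one-hot categorical sampling can be sketched in plain torch (this is an illustration of the mechanics, not the library's actual implementation):

```python
import torch

torch.manual_seed(0)
p = torch.tensor([[0.1, 0.3, 0.6], [0.0, 0.0, 1.0]])

# Draw one category index per row according to its probabilities,
# then one-hot encode the index, matching the samples above.
index = torch.multinomial(p, num_samples=1).squeeze(-1)
x = torch.nn.functional.one_hot(index, num_classes=p.shape[-1]).float()
```

Note that the second row of p places all mass on the last category, so its sample is always [0., 0., 1.].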

Noise tensor

The categorical distribution returns None in place of a noise tensor \epsilon, as it would have no meaningful interpretation. Therefore, you must ignore the second return value when sampling.


If you spot any typo or technical imprecision, please submit an issue or pull request to the library's GitHub repository.