
Loss Function

The loss function of the denoising network plays a crucial role in the quality of the samples generated by Diffusion Models. Modular Diffusion ships with the recurring $L_\text{simple}$ and $L_\text{vlb}$ functions, as well as a Lambda utility to build your own custom loss function.

Hybrid losses

To create a hybrid loss, simply add different loss modules together with a weight. For instance, to create a loss function that is a combination of $L_\text{simple}$ and $L_\text{vlb}$, you could write loss = Simple() + 0.001 * VLB().
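
As a complete snippet, using the Simple and VLB modules documented below, this reproduces the hybrid objective proposed by Nichol & Dhariwal (2021):

from diffusion.loss import Simple, VLB

# Weighted sum of loss modules: L_simple + 0.001 * L_vlb.
loss = Simple() + 0.001 * VLB()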

Training batch

While not a loss module, the Batch object is a fundamental component of Modular Diffusion. It is used to store the data that is fed to the loss module during training. When creating custom loss modules, it is important to know the names used to refer to the different tensors stored in the Batch object, listed below.

Properties

  • w -> Initial data tensor $w$.
  • x -> Data tensor after transform $x_0$.
  • y -> Label tensor $y$.
  • t -> Time step tensor $t$.
  • epsilon -> Noise tensor $\epsilon$. May be None for certain noise types.
  • z -> Latent tensor $x_t$.
  • hat -> Predicted tensor $\hat{x}_\theta$, $\hat{\epsilon}_\theta$, or other(s), depending on the parametrization.
  • q -> Posterior distribution $q(x_{t-1}|x_t, x_0)$.
  • p -> Approximate posterior distribution $p_\theta(x_{t-1}|x_t)$.
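
For instance, a custom loss function might compare the noise tensor with the network's prediction. The sketch below is a minimal example, assuming the epsilon parametrization, where hat[0] holds $\hat{\epsilon}_\theta$:

def custom_loss(batch):
    # Mean squared error between the true noise and the network's
    # prediction, using the epsilon and hat tensors listed above.
    return ((batch.epsilon - batch.hat[0]) ** 2).mean()

A callable like this can be passed to the Lambda utility described next.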

Lambda function

Custom loss module defined using a lambda function and parametrized with a distribution type. It is meant to be used as shorthand for writing a custom loss module class.

Parameters

  • function -> Callable which receives a Batch object and returns a Tensor containing the loss value.

Example

from diffusion.loss import Lambda
from diffusion.distribution import Normal as N

# Mean squared error between the means of the true and the
# approximate posterior distributions.
loss = Lambda[N](lambda b: ((b.q.mu - b.p.mu)**2).mean())

Type checking

If you are using a type checker or want useful IntelliSense, you will need to explicitly parametrize the Lambda class with a Distribution type, as seen in the example above.

Simple loss function

Simple MSE loss introduced by Ho et al. (2020) in the context of Diffusion Models. Depending on the parametrization, it is defined as:

  • $L_\text{simple}=\mathbb{E}\left[\lvert\lvert x-\hat{x}_\theta\rvert\rvert^2\right]$ for the x parametrization;
  • $L_\text{simple}=\mathbb{E}\left[\lvert\lvert\epsilon-\hat{\epsilon}_\theta\rvert\rvert^2\right]$ for the epsilon parametrization.

Parameters

  • parameter (default "x") -> Parameter to be learned and used to compute the loss. Either "x" ($\hat{x}_\theta$) or "epsilon" ($\hat{\epsilon}_\theta$).
  • index (default 0) -> Index of the hat tensor which corresponds to the selected parameter.

Parametrization

If you have the option, remember to select the same parameter in both your model's Noise and Loss objects.

Example

from diffusion.loss import Simple

loss = Simple(parameter="epsilon")
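
To follow the parametrization note above, pair the loss with a matching Noise object. The snippet below is a sketch assuming a Gaussian noise module that accepts the same parameter keyword; check the Noise documentation for the exact signature:

from diffusion.loss import Simple
from diffusion.noise import Gaussian

# Both modules use the epsilon parametrization, so the network's
# prediction and the loss target refer to the same quantity.
noise = Gaussian(parameter="epsilon")
loss = Simple(parameter="epsilon")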

Variational lower bound

In the context of Diffusion Models, the variational lower bound (VLB) of $\log p(x_0)$ is given by:

$$
\begin{aligned}
L_\text{vlb} & = \mathbb{E}_{q(x_{1}|x_0)}\left[\log p_{\theta}(x_0|x_1)\right] \\
& - \sum_{t=2}^{T} \mathbb{E}_{q(x_{t}|x_0)}\left[D_{KL}(q(x_{t-1}|x_t, x_0)||p_{\theta}(x_{t-1}|x_t))\right] \\
& - D_{KL}(q(x_T|x_0)||p(x_T))\text{,}
\end{aligned}
$$

where $D_{KL}(q(x_T|x_0)||p(x_T))$ is considered to be equal to 0 under standard assumptions, since it contains no learnable parameters and the forward process is designed so that $q(x_T|x_0)$ closely matches the prior $p(x_T)$.
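
Each KL term in the sum compares the two distributions exposed on the Batch object as q and p. As a hypothetical sketch using the Lambda utility, assuming the Normal distribution exposes a sigma attribute alongside mu:

import torch

from diffusion.distribution import Normal as N
from diffusion.loss import Lambda

def kl_normal(q, p):
    # Closed-form KL divergence between two diagonal Gaussians.
    return (torch.log(p.sigma / q.sigma)
            + (q.sigma ** 2 + (q.mu - p.mu) ** 2) / (2 * p.sigma ** 2)
            - 0.5)

loss = Lambda[N](lambda b: kl_normal(b.q, b.p).mean())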

Parameters

This module has no parameters.

Example

from diffusion.loss import VLB

loss = VLB()

If you spot any typo or technical imprecision, please submit an issue or pull request to the library's GitHub repository.