Loss Function
The loss function of the denoising network plays a crucial role in the quality of the samples generated by Diffusion Models. Modular Diffusion ships with the recurring `Simple` and `VLB` loss functions, as well as a `Lambda` utility to build your own custom loss function.
Hybrid losses
To create a hybrid loss, simply add different loss modules together, optionally scaling each by a weight. For instance, to create a loss function that is a combination of the `Simple` and `VLB` losses, you could write:
loss = Simple() + 0.001 * VLB()
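Under the hood, this arithmetic syntax can be supported with standard Python operator overloading. The classes below are hypothetical stand-ins written to illustrate the idea, not Modular Diffusion's actual implementation:

```python
# Hypothetical sketch of how loss modules could support `+` and `*`
# to form weighted hybrid losses. Placeholder values stand in for
# real per-batch loss computations.

class Loss:
    def __call__(self, batch):
        raise NotImplementedError

    def __add__(self, other):
        # `a + b` yields a loss that sums both values.
        return Hybrid(self, other)

    def __rmul__(self, weight):
        # Enables `0.001 * VLB()`-style scaling.
        return Scaled(weight, self)

class Hybrid(Loss):
    def __init__(self, a, b):
        self.a, self.b = a, b

    def __call__(self, batch):
        return self.a(batch) + self.b(batch)

class Scaled(Loss):
    def __init__(self, weight, loss):
        self.weight, self.loss = weight, loss

    def __call__(self, batch):
        return self.weight * self.loss(batch)

class Simple(Loss):
    def __call__(self, batch):
        return 2.0  # placeholder loss value

class VLB(Loss):
    def __call__(self, batch):
        return 100.0  # placeholder loss value

loss = Simple() + 0.001 * VLB()
print(loss(None))  # 2.0 + 0.001 * 100.0, i.e. ~2.1
```

Because every combination is itself a `Loss`, weighted sums compose freely into arbitrarily nested hybrid objectives.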
Training batch
While not a loss module, the `Batch` object is a fundamental component of Modular Diffusion. It is used to store the data that is fed to the loss module during training. When creating custom loss modules, it is important to know the names used to refer to the different tensors stored in the `Batch` object, listed below.
Properties
- `w` -> Initial data tensor $w$.
- `x` -> Data tensor after transform $x_0$.
- `y` -> Label tensor $y$.
- `t` -> Time step tensor $t$.
- `epsilon` -> Noise tensor $\epsilon$. May be `None` for certain noise types.
- `z` -> Latent tensor $x_t$.
- `hat` -> Predicted tensor $\hat{x}_0$, $\hat{\epsilon}$, or other(s), depending on the parametrization.
- `q` -> Posterior distribution $q(x_{t-1} \mid x_t, x_0)$.
- `p` -> Approximate posterior distribution $p_\theta(x_{t-1} \mid x_t)$.
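To make these names concrete, a minimal stand-in for the `Batch` object can be sketched as a plain dataclass. This is only an illustration of the fields a custom loss module can read, not the library's actual class:

```python
# Hypothetical sketch of the fields a training batch carries.
# Plain Python lists stand in for tensors.
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Batch:
    w: Any                  # initial data tensor
    x: Any                  # data tensor after transform
    y: Any                  # label tensor
    t: Any                  # time step tensor
    epsilon: Optional[Any]  # noise tensor; None for certain noise types
    z: Any                  # latent tensor
    hat: Any                # predicted tensor(s), depending on parametrization
    q: Any                  # posterior distribution
    p: Any                  # approximate posterior distribution

# A custom loss reads whichever fields it needs, e.g. the noise
# tensor and the prediction:
def mse_on_noise(batch: Batch) -> float:
    diffs = [(e - h) ** 2 for e, h in zip(batch.epsilon, batch.hat)]
    return sum(diffs) / len(diffs)
```

A real loss module would receive the library's own `Batch` with tensor-valued fields, but the attribute names it accesses are the ones listed above.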
Lambda function
Custom loss module that is defined using a lambda function and parametrized with a distribution. It is meant to be used as shorthand for writing a custom loss function class.
Parameters
- `function` -> Callable which receives a `Batch` object and returns a `Tensor` containing the loss value.
Example
from diffusion.loss import Lambda
from diffusion.distribution import Normal as N
loss = Lambda[N](lambda b: ((b.q.mu - b.p.mu)**2).mean())
Type checking
If you are using a type checker or want useful intellisense, you will need to explicitly parametrize the `Lambda` class with a `Distribution` type, as seen in the example.
Simple loss function
Simple MSE loss introduced by Ho et al. (2020) in the context of Diffusion Models. Depending on the parametrization, it is defined as:
$$L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\left[\left\lVert x_0 - \hat{x}_0 \right\rVert^2\right] \quad \text{or} \quad L_{\text{simple}} = \mathbb{E}_{t, x_0, \epsilon}\left[\left\lVert \epsilon - \hat{\epsilon} \right\rVert^2\right].$$
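A scalar version of this objective (here for the `epsilon` parametrization, with plain Python lists standing in for tensors) can be sketched as:

```python
def simple_loss(epsilon, epsilon_hat):
    """Mean squared error between the true and predicted noise:
    a scalar sketch of the Simple objective for one batch."""
    assert len(epsilon) == len(epsilon_hat)
    return sum((e - h) ** 2 for e, h in zip(epsilon, epsilon_hat)) / len(epsilon)

print(simple_loss([1.0, 2.0, 3.0], [1.0, 1.0, 3.0]))  # (0 + 1 + 0) / 3
```

In practice the expectation over $t$, $x_0$, and $\epsilon$ is approximated by averaging over the mini-batch.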
Parameters
- `parameter` (default `"x"`) -> Parameter to be learned and used to compute the loss. Either `"x"` ($\hat{x}_0$) or `"epsilon"` ($\hat{\epsilon}$).
- `index` (default `0`) -> Index of the `hat` tensor which corresponds to the selected `parameter`.
Parametrization
If you have the option, always remember to select the same parameter in both your model's `Noise` and `Loss` objects.
Example
from diffusion.loss import Simple
loss = Simple(parameter="epsilon")
Variational lower bound
In the context of Diffusion Models, the variational lower bound (VLB) of $\log p_\theta(x_0)$ is given by:

$$L_{\text{vlb}} = \mathbb{E}_q\Big[\underbrace{D_{KL}\big(q(x_T \mid x_0) \,\|\, p(x_T)\big)}_{L_T} + \sum_{t > 1} \underbrace{D_{KL}\big(q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t)\big)}_{L_{t-1}} \underbrace{- \log p_\theta(x_0 \mid x_1)}_{L_0}\Big],$$

where $L_T$ is considered to be equal to 0 under standard assumptions.
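Each $L_{t-1}$ term is a KL divergence between two Gaussians, which has a closed form. A univariate sketch of that formula (the actual computation runs over full tensors) is:

```python
import math

def kl_normal(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2))
    for univariate Gaussians."""
    return (
        math.log(sigma_p / sigma_q)
        + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
        - 0.5
    )

print(kl_normal(0.0, 1.0, 0.0, 1.0))  # identical distributions -> 0.0
```

This closed form is what makes the VLB tractable: no sampling is needed to evaluate the KL terms between the Gaussian posterior `q` and the model's approximate posterior `p`.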
Parameters
This module has no parameters.
Example
from diffusion.loss import VLB
loss = VLB()
If you spot any typo or technical imprecision, please submit an issue or pull request to the library's GitHub repository.