Modular Diffusion

Denoising Network

The backbone of Diffusion Models is a denoising network, which is trained to gradually denoise data. While earlier works used a U-Net architecture, newer research has shown that Transformers can achieve comparable or superior results. Modular Diffusion ships with both types of denoising network. Both are implemented in PyTorch and thinly wrapped in a Net module.

Future warning

The current denoising network implementations are not necessarily the most efficient or the most effective and are bound to change in a future release. They do, however, provide a great starting point for experimentation.

U-Net

U-Net implementation adapted from The Annotated Diffusion Model. It takes an input with shape [b, c, h, w] and returns an output with shape [p, b, c, h, w].

Parameters

  • channels -> Sequence of integers representing the number of channels in each layer of the U-Net.
  • labels (default 0) -> Number of unique labels in y.
  • parameters (default 1) -> Number of output parameters p.
  • hidden (default 256) -> Hidden dimension.
  • heads (default 8) -> Number of attention heads.
  • groups (default 16) -> Number of groups in the group normalization layers.

Example

from diffusion.net import UNet

net = UNet(channels=(3, 64, 128, 256), labels=10)
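A quick way to sanity-check the documented shapes is to run a dummy batch through the network. The sketch below is a minimal example, assuming the forward pass takes the noised input x, a label tensor y, and a timestep tensor t; the exact calling convention is an assumption, so verify it against the library source if it differs.

import torch

x = torch.randn(4, 3, 32, 32)    # dummy batch of 4 RGB 32x32 images: [b, c, h, w]
y = torch.randint(0, 10, (4,))   # one of the 10 class labels per image
t = torch.randint(1, 1000, (4,)) # one diffusion timestep per image

output = net(x, y, t)            # assumed signature: net(x, y, t)
print(output.shape)              # expected: [p, b, c, h, w] = [1, 4, 3, 32, 32]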

Transformer

Transformer implementation adapted from Peebles & Xie (2022) (adaptive layer norm block). It takes an input with shape [b, l, e] and returns an output with shape [p, b, l, e].

Parameters

  • input -> Input embedding dimension e.
  • labels (default 0) -> Number of unique labels in y.
  • parameters (default 1) -> Number of output parameters p.
  • depth (default 256) -> Number of transformer blocks.
  • width (default 256) -> Hidden dimension.
  • heads (default 8) -> Number of attention heads.

Example

from diffusion.net import Transformer

# x is your training data, with shape [b, l, e]
net = Transformer(input=x.shape[2])
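The same kind of shape check works here. As before, this is a sketch under the assumed net(x, y, t) calling convention, with placeholder y and t tensors made up for illustration.

import torch

x = torch.randn(4, 16, 64)           # dummy batch: [b, l, e] = 4 sequences, 16 tokens, embedding dim 64
net = Transformer(input=x.shape[2])  # e = 64

y = torch.zeros(4, dtype=torch.long) # dummy labels (this net was built with the default labels=0)
t = torch.randint(1, 1000, (4,))     # one diffusion timestep per sequence

output = net(x, y, t)                # assumed signature: net(x, y, t)
print(output.shape)                  # expected: [p, b, l, e] = [1, 4, 16, 64]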

If you spot any typo or technical imprecision, please submit an issue or pull request to the library's GitHub repository.