
Modular Diffusion


Data Transform

In many diffusion models, the diffusion process unfolds within the dataset space. However, a growing number of algorithms, like Stable Diffusion, project data onto a latent space before applying diffusion. Modular Diffusion includes an Identity transform that lets you use your data as-is, and also ships with a collection of other data transforms.


Throughout this page, we use $x$ rather than $x_0$ to denote the transformed data for increased readability. Any indexing of $x$ should be interpreted as accessing its individual elements.

Identity transform

Does not alter the input data. The transform is given by:

  • $x = w$
  • $w = x$.


  • w -> Input tensor $w$.
  • y (default: None) -> Optional label tensor $y$.
  • batch (default: 1) -> Number of samples per training batch.
  • shuffle (default: True) -> Whether to shuffle the data before each epoch.


import torch
from import Identity

w = torch.tensor([[1, 2, 3]])
data = Identity(w)
x = data.transform(next(data))
# x = tensor([[1, 2, 3]])
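
The batch and shuffle parameters listed above describe ordinary mini-batch iteration. As a rough illustration only, the sketch below shows what such iteration could look like in plain PyTorch. The helper name iterate_batches is hypothetical and is not part of Modular Diffusion; how the library actually implements batching is not shown on this page.

```python
import torch


def iterate_batches(w, batch=1, shuffle=True):
    """Yield mini-batches of rows of `w` (hypothetical helper, not library code)."""
    n = w.shape[0]
    # Visit samples in a random order when shuffling, otherwise in their original order.
    order = torch.randperm(n) if shuffle else torch.arange(n)
    for start in range(0, n, batch):
        # The last batch may be smaller when n is not a multiple of `batch`.
        yield w[order[start:start + batch]]


w = torch.arange(10).reshape(5, 2)  # 5 samples with 2 features each
batches = list(iterate_batches(w, batch=2, shuffle=False))
shuffled = torch.cat(list(iterate_batches(w, batch=2, shuffle=True)))
```

With 5 samples and batch=2, this yields two batches of 2 samples and a final batch of 1. Shuffling changes only the order in which samples are visited, not which samples appear.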

One-hot vector transform

Represents the input data as one-hot vectors. The transform is given by:

  • $x_{\dots ij} = \begin{cases} 1 & \text{if } j = w_{\dots i} \\ 0 & \text{otherwise} \end{cases}$
  • $w_{\dots i} = \underset{\text{j}}{\text{argmax}}(x_{\dots ij})$.


  • w -> Input tensor $w$.
  • y (default: None) -> Optional label tensor $y$.
  • k -> Number of categories $k$.
  • batch (default: 1) -> Number of samples per training batch.
  • shuffle (default: True) -> Whether to shuffle the data before each epoch.


import torch
from import OneHot

w = torch.tensor([[0, 2, 2]])
data = OneHot(w, k=3)
x = data.transform(next(data))
# x = tensor([[[1, 0, 0],
#              [0, 0, 1],
#              [0, 0, 1]]])
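
The two equations of the one-hot transform can be checked independently of the library. The following is a minimal sketch in plain PyTorch, assuming integer category indices in $[0, k)$. It uses torch.nn.functional.one_hot for the forward direction and argmax over the last axis for the inverse, and makes no claim about how OneHot is implemented internally.

```python
import torch
import torch.nn.functional as F

w = torch.tensor([[0, 2, 2]])  # category indices, shape (1, 3)
k = 3                          # number of categories

# Forward: x[..., i, j] = 1 if j == w[..., i], else 0.
x = F.one_hot(w, num_classes=k)

# Inverse: w[..., i] = argmax_j x[..., i, j].
w_rec = x.argmax(dim=-1)
```

Here x has shape (1, 3, 3) and w_rec equals w, so the round trip is lossless.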

Embedding space transform

Represents the input data in the embedding space. The embedding matrix is initialized with random values and updated during training. Let $\text{E} \in \mathbb{R}^{k \times d}$ be the embedding matrix, where $k$ is the number of categories and $d$ is the embedding dimension. Then the transform is defined as:

  • $x_{\dots ij} = \text{E}_{w_{\dots i}j}$
  • $w_{\dots i} = \underset{\text{k}}{\text{argmin}}\left(\underset{\text{i, k}}{\text{cdist}}\left(x_{\dots ij}, \text{E}_{kj}\right)\right)$.


  • w -> Input tensor $w$.
  • y (default: None) -> Optional label tensor $y$.
  • k -> Number of categories $k$.
  • d -> Embedding dimension $d$.
  • batch (default: 1) -> Number of samples per training batch.
  • shuffle (default: True) -> Whether to shuffle the data before each epoch.


import torch
from import Embedding

w = torch.tensor([[0, 2, 2]])
data = Embedding(w, k=3, d=5)
x = data.transform(next(data))
# x = tensor([[[0.201, -0.415, 0.683, -0.782, 0.039],
#              [-0.509, 0.893, 0.102, -0.345, 0.623],
#              [-0.509, 0.893, 0.102, -0.345, 0.623]]])
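
The forward lookup and the nearest-neighbour inverse defined above can also be reproduced with standard PyTorch building blocks. The sketch below is an assumption-laden illustration and not the library's code. It stands in torch.nn.Embedding for the matrix $\text{E}$ and uses torch.cdist for the pairwise Euclidean distances. The values are random, so they will not match the numbers printed in the example above.

```python
import torch

torch.manual_seed(0)             # fixed seed so the sketch is reproducible
k, d = 3, 5                      # number of categories, embedding dimension
emb = torch.nn.Embedding(k, d)   # stand-in for the embedding matrix E
w = torch.tensor([[0, 2, 2]])    # category indices, shape (1, 3)

with torch.no_grad():
    E = emb.weight               # shape (k, d)
    # Forward: x[..., i, :] = E[w[..., i], :].
    x = emb(w)                   # shape (1, 3, d)
    # Inverse: index of the row of E closest to each vector in x.
    dist = torch.cdist(x.reshape(-1, d), E)        # shape (3, k)
    w_rec = dist.argmin(dim=-1).reshape(w.shape)   # shape (1, 3)
```

Because every row of x is copied from E, its distance to the originating row is zero, and w_rec recovers w as long as the rows of E are distinct. After diffusion noise is added, the same argmin maps each vector to its nearest category.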

If you spot any typo or technical imprecision, please submit an issue or pull request to the library's GitHub repository.