Modular Diffusion

Diffusion Model

In Modular Diffusion, the Model class is a high-level interface for designing and training your own custom diffusion models. It acts as a container for all the modules that make up a diffusion model.

Parameters

  • data -> Data transform module.
  • schedule -> Noise schedule module.
  • noise -> Noise type module.
  • net -> Denoising network module.
  • loss -> Loss function module.
  • guidance (Default: None) -> Optional guidance module.
  • optimizer (Default: partial(Adam, lr=1e-4)) -> PyTorch optimizer constructor function.
  • device (Default: "cpu") -> Device to train the model on.
  • compile (Default: True) -> Whether to compile the model with torch.compile for faster training.
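
The optimizer parameter takes a constructor function rather than an optimizer instance. The sketch below illustrates that pattern with functools.partial and a hypothetical stand-in class, ToyOptimizer. It is not PyTorch's or Modular Diffusion's code, and the way the library calls the constructor internally is an assumption made for illustration.

```python
from functools import partial


class ToyOptimizer:
    """Hypothetical stand-in for an optimizer class such as torch.optim.Adam."""

    def __init__(self, params, lr=1e-3):
        self.params = list(params)
        self.lr = lr


# Bind the hyperparameters now; the trainable parameters are supplied later.
make_optimizer = partial(ToyOptimizer, lr=1e-4)

# Later, whoever owns the trainable parameters completes the construction
# (assumed here to be what the model does internally).
toy_params = [0.0, 1.0, 2.0]
optimizer = make_optimizer(toy_params)
```

Passing a partial keeps the hyperparameters in the model definition while leaving the parameter list to be filled in once the network exists.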

Example

import diffusion
import torch
from diffusion.data import Identity
from diffusion.guidance import ClassifierFree
from diffusion.loss import Simple
from diffusion.net import UNet
from diffusion.noise import Gaussian
from diffusion.schedule import Cosine
from torch.optim import AdamW
from functools import partial

# x and y are assumed to be the training data and label tensors, defined elsewhere
model = diffusion.Model(
    data=Identity(x, y, batch=128, shuffle=True),
    schedule=Cosine(steps=1000),
    noise=Gaussian(parameter="epsilon", variance="fixed"),
    net=UNet(channels=(1, 64, 128, 256), labels=10),
    loss=Simple(parameter="epsilon"),
    guidance=ClassifierFree(dropout=0.1, strength=2),
    optimizer=partial(AdamW, lr=3e-4),
    # Fall back to the CPU when no CUDA device is available
    device="cuda" if torch.cuda.is_available() else "cpu",
)

Train the model

Model.train trains the model for a specified number of epochs. It returns a generator that yields the current loss when each epoch is finished, allowing the user to easily validate the model between epochs inside a for loop.

Parameters

  • epochs (default: 1) -> Number of epochs to train the model.
  • progress (default: True) -> Whether to display a progress bar for each epoch.

Examples

# Train model without validation
losses = [*model.train(epochs=100)]
# Train model with validation
for epoch, loss in enumerate(model.train(epochs=100)):
    if epoch % 10 == 0:
        # Validate your model here
        model.save("model.pt")
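
Because Model.train returns a generator, the loop body runs between epochs and can also stop training early by breaking out of the loop. The sketch below shows that control flow with a hypothetical stand-in generator, toy_train, which yields made-up losses. It does not call Modular Diffusion, and the patience-based stopping rule is an illustrative choice, not a library feature.

```python
def toy_train(epochs=1):
    """Hypothetical stand-in for model.train: yields one loss per epoch."""
    fake_losses = [1.0, 0.6, 0.5, 0.5, 0.5, 0.5, 0.4]  # made-up values
    for epoch in range(epochs):
        # A real training epoch would run here before the loss is yielded.
        yield fake_losses[min(epoch, len(fake_losses) - 1)]


best_loss = float("inf")
patience, stale_epochs = 2, 0
history = []

for epoch, loss in enumerate(toy_train(epochs=7)):
    history.append(loss)
    if loss < best_loss:
        # Improvement: remember the loss and reset the counter.
        best_loss, stale_epochs = loss, 0
    else:
        stale_epochs += 1
    if stale_epochs >= patience:
        # Breaking abandons the generator, so no further epochs are run.
        break
```

With the real model, the improvement branch would be a natural place to call model.save.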

Sample from the model

Model.sample samples from the model for a specified batch size and label tensor. It returns a tensor with shape [t, b, ...] where t is the number of time steps, b is the batch size, and ... represents the shape of the data. This allows the user to visualize the sampling process.

Parameters

  • y (default: None) -> Optional label tensor y to condition sampling.
  • batch (default: 1) -> Number of samples to generate. If y is not None, this is the number of samples per label.
  • progress (default: True) -> Whether to display a progress bar.

Examples

# Save only final sampling results
*_, z = model.sample(batch=10)
# Save entire sampling process
z = model.sample(batch=10)
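
The two examples above differ only in which time steps are kept. The sketch below mimics a [t, b, ...] result with nested Python lists (a hypothetical stand-in, not a tensor returned by Model.sample) to show that starred unpacking keeps the last entry along the first axis, which is the same element as indexing with -1.

```python
# Hypothetical stand-in for a sampling result with t=3 time steps and b=2 samples.
# Each inner value plays the role of one sample's data at that time step.
steps = [
    ["noise_0", "noise_1"],      # first time step
    ["partial_0", "partial_1"],  # intermediate time step
    ["final_0", "final_1"],      # last time step
]

# Starred unpacking discards every time step except the last one.
*_, z = steps

# Indexing the first axis with -1 selects the same entry.
same_as_index = z == steps[-1]
```

Keeping the full result is useful for visualizing the sampling process, while the unpacked form holds on to only the final samples.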

Load the model

Model.load loads the model’s trainable weights from a file. The model should be initialized with the same trainable modules it was initially trained with. If a trainable module is replaced with a different module, the model will not load correctly.

Parameters

  • path -> Path to the file containing the model’s weights.

Example

import diffusion
from pathlib import Path

model = diffusion.Model(...)
# Load saved weights only if a checkpoint already exists
if Path("model.pt").exists():
    model.load("model.pt")

Save the model

Model.save saves the model’s trainable weights to a file.

Parameters

  • path -> Path to the file to save the model’s weights to.

Example

model.save("model.pt")

If you spot any typo or technical imprecision, please submit an issue or pull request to the library's GitHub repository.