Building a Diffusion Model From Scratch

The rise of Generative AI has been nothing short of explosive. Just a few years ago, generating high-quality images was the domain of complex and often unstable Generative Adversarial Networks (GANs). Then came diffusion models, a paradigm shift that has since dominated the generative landscape. Tools built on them, such as DALL-E 2, Midjourney, and Stable Diffusion, turn text prompts into striking images. If you want to learn how to build a diffusion model from scratch, this article is for you: I’ll take you through it step by step.

But what exactly is a Diffusion Model?

At its core, it’s a model that learns to reverse a process of gradual destruction. Think of it like this: You have a clear, high-quality image of a dog. You then progressively add a little bit of noise, like static on a TV screen, over many small steps until the image is completely unrecognizable, just pure noise. That’s the forward diffusion process. Now, the reverse diffusion process begins. The model learns to reverse each of those tiny noise-adding steps, one by one, to turn the pure noise back into the original clear image.
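
To make the "gradual destruction" idea concrete, here is a minimal, dependency-free sketch in plain Python. It is only an illustration, not the code we build below: the constant noise level `beta=0.02` and the helper name `destroy` are assumptions chosen for simplicity. It repeatedly mixes a clean value with a little Gaussian noise and shows that, after a few hundred steps, a population of identical "pixels" looks like standard Gaussian noise.

```python
import math
import random

rng = random.Random(0)  # fixed seed so the run is repeatable


def destroy(x, steps=300, beta=0.02):
    """Mix a value with a little Gaussian noise, `steps` times in a row."""
    keep = math.sqrt(1.0 - beta)   # how much of the current value survives one step
    add = math.sqrt(beta)          # how much fresh noise is mixed in per step
    for _ in range(steps):
        x = keep * x + add * rng.gauss(0.0, 1.0)
    return x


# Start 2,000 "pixels" at the same clean value and destroy each one.
samples = [destroy(1.0) for _ in range(2000)]
mean = sum(samples) / len(samples)
variance = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean={mean:.3f}, variance={variance:.3f}")
```

The printed mean should be close to 0 and the variance close to 1, which is what "pure noise" means here: almost nothing of the starting value 1.0 is left.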

The brilliance of this approach is in its simplicity and stability. Instead of trying to generate a complex image from scratch in one go, the model learns to solve a much simpler problem at each step: “What noise do I need to remove to get from this slightly noisy image to the one before it?”

This step-by-step denoising process is what enables diffusion models to produce high-fidelity, diverse, and stable results.

Building a Diffusion Model From Scratch

Before we can build anything, we need to prepare our workspace. We’ll define our hyperparameters and set up our data pipeline:

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm
import os

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 128
IMG_SIZE = 32 # We'll resize MNIST images to 32x32
TIMESTEPS = 300 # Number of steps in the diffusion process
LEARNING_RATE = 1e-3
EPOCHS = 20
OUTPUT_DIR = "diffusion_outputs"

os.makedirs(OUTPUT_DIR, exist_ok=True)

def get_data_loader():
    """
    Prepares and returns the MNIST DataLoader.
    """
    # Define transformations for the images
    # 1. Resize to IMG_SIZE
    # 2. Convert to Tensor
    # 3. Normalize to [-1, 1] range
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1)
    ])

    dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    return dataloader

Here, we first defined a few key parameters, such as BATCH_SIZE (the number of images to process simultaneously) and TIMESTEPS (the number of steps in our diffusion process). The DEVICE is also set to “cuda” if a GPU is available, which is crucial for speeding up training.

We created a function called get_data_loader() that handles loading and transforming our dataset, specifically the MNIST dataset. The key transformations are resizing the images to 32×32 and normalizing the pixel values. ToTensor() first scales the raw [0, 255] pixel values to [0, 1], and the Lambda transform then maps them to [-1, 1]. This keeps the data centered around zero, like the Gaussian noise we add later, and is a common practice that helps the model learn more effectively.
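
As a quick sanity check on the value ranges, here is a small plain-Python sketch of the arithmetic. The function names are hypothetical and exist only for this illustration; the first one mimics what ToTensor() does to an 8-bit pixel, the second is the Lambda transform from the code above, and the third is the inverse mapping needed before writing an image to disk.

```python
def to_unit_range(pixel):
    """What ToTensor() does to an 8-bit pixel: [0, 255] -> [0.0, 1.0]."""
    return pixel / 255.0


def to_model_range(value):
    """The Lambda transform: [0, 1] -> [-1, 1]."""
    return value * 2 - 1


def to_image_range(value):
    """Inverse mapping for saving images: [-1, 1] -> [0, 1]."""
    return (value + 1) * 0.5


for pixel in (0, 128, 255):
    scaled = to_model_range(to_unit_range(pixel))
    restored = to_image_range(scaled)
    print(pixel, round(scaled, 4), round(restored, 4))
```

Black (0) maps to -1, white (255) maps to 1, and the inverse mapping brings every value back into [0, 1].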

Step 1: The Forward Process (Adding the Noise)

This is where we implement the first half of the diffusion model, adding noise. While we could iteratively add noise over hundreds of steps, there’s a neat mathematical trick to calculate the noisy image at any given timestep directly:

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """
    Creates a linear variance schedule.

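    Args:
        timesteps (int): The number of timesteps.
        start (float): The beta value at the first timestep.
        end (float): The beta value at the last timestep.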

    Returns:
        torch.Tensor: A tensor of beta values.
    """
    return torch.linspace(start, end, timesteps)


# Get the beta schedule
betas = linear_beta_schedule(timesteps=TIMESTEPS).to(DEVICE)

# Calculate alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0) # Cumulative product of alphas
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) # Previous cumulative product

# Calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# Calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)


def extract(a, t, x_shape):
    """
    Extracts the values from 'a' at the indices 't' and reshapes it to
    match the batch dimension of 'x'.
    """
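    batch_size = t.shape[0]
    # Index on the same device as 'a'; gather fails if the two tensors are on different devices
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)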


def q_sample(x_start, t, noise=None):
    """
    Forward diffusion process: adds noise to an image.

    Args:
        x_start (torch.Tensor): The initial image (x_0).
        t (torch.Tensor): The timestep index.
        noise (torch.Tensor, optional): The noise to add. If None, generated randomly.

    Returns:
        torch.Tensor: The noised image at timestep t.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)

    # Equation for noising: x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * noise
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

Here, we defined a linear_beta_schedule that creates a sequence of noise levels (betas) for each timestep. The beta values are small and increase linearly over time, ensuring the noise is added gradually.
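
To get a feel for the numbers, here is a plain-Python sketch of the same schedule, with no PyTorch needed. The helper name `linear_schedule` is hypothetical; it is meant to mirror torch.linspace with the default start and end values used above.

```python
def linear_schedule(timesteps, start=0.0001, end=0.02):
    """Evenly spaced values from start to end, like torch.linspace."""
    step = (end - start) / (timesteps - 1)
    return [start + i * step for i in range(timesteps)]


betas = linear_schedule(300)

# alpha_bar_t is the running product of (1 - beta). It tells us what
# fraction of the original signal's variance survives up to step t.
alpha_bar = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alpha_bar.append(running)

for t in (0, 99, 199, 299):
    print(f"t={t:3d}  beta={betas[t]:.5f}  alpha_bar={alpha_bar[t]:.4f}")
```

The betas grow steadily while alpha_bar shrinks toward zero, which is exactly the "gradual destruction" we want from the forward process.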

The code then pre-calculates several values derived from the betas, such as alphas_cumprod (the cumulative product of alphas), which are essential for our main diffusion formula. This is a standard optimization to enhance the process’s efficiency.

The function q_sample(x_start, t, noise) takes a clean image x_start and a timestep t, and blends the image with Gaussian noise using the pre-calculated coefficients for that timestep. The formula used here, x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * noise, enables us to skip directly to the noisy image at step t without having to process every previous step.
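
If you want to convince yourself that the shortcut really matches the step-by-step process, the following plain-Python sketch checks it numerically. It assumes the usual one-step rule x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * noise and tracks two quantities while stepping through the schedule: the multiplier on x_0 and the accumulated noise variance. Both should agree with the closed-form coefficients.

```python
import math

# Same linear schedule as in the article: 300 steps from 0.0001 to 0.02.
betas = [0.0001 + i * (0.02 - 0.0001) / 299 for i in range(300)]

mean_coeff = 1.0  # multiplier on x_0 after applying every step so far
variance = 0.0    # variance of the accumulated noise so far
alpha_bar = 1.0   # cumulative product of alphas (alphas_cumprod)

for beta in betas:
    alpha = 1.0 - beta
    # One iterative step scales the current image by sqrt(alpha)
    # and adds fresh noise with variance beta.
    mean_coeff *= math.sqrt(alpha)
    variance = alpha * variance + beta
    alpha_bar *= alpha

print(mean_coeff, math.sqrt(alpha_bar))  # iterative vs. closed-form signal weight
print(variance, 1.0 - alpha_bar)         # iterative vs. closed-form noise variance
```

Each printed pair should match up to floating-point rounding, which is why q_sample can jump straight to any timestep.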

Step 2: Building the Brain of the Model (The U-Net)

The model we use to reverse the noise is a special type of neural network called a U-Net. A U-Net is perfect for this task because it’s designed to take an image as input and output an image of the same size. It’s called a U-Net because of its U-shaped architecture, which has a contracting path (downsampling) and an expansive path (upsampling) with skip connections in between. Let’s build the brain of our model:

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = torch.log(torch.tensor(10000.0, device=device)) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(...,) + (None,) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Up sample
        return self.transform(h)

class SimpleUnet(nn.Module):
    def __init__(self):
        super().__init__()
        image_channels = 1 # MNIST is grayscale
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 1
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU()
        )

        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsampling
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                    time_emb_dim) \
                    for i in range(len(down_channels)-1)])
        # Upsampling
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                        time_emb_dim, up=True) \
                    for i in range(len(up_channels)-1)])

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        # Embed time
        t = self.time_mlp(timestep)
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)

A crucial detail is that our model needs to know what timestep it’s currently on, so it knows how much noise to remove. We can’t just feed the timestep number (e.g., 250) directly. Instead, we use SinusoidalPositionEmbeddings to convert the integer timestep into a meaningful vector representation.
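
Here is a plain-Python sketch of what that embedding computes for a single timestep. It follows the same formula as the SinusoidalPositionEmbeddings class, just with lists instead of tensors; the function name is hypothetical.

```python
import math


def sinusoidal_embedding(timestep, dim=32):
    """Turn one integer timestep into a vector of `dim` sine/cosine values."""
    half_dim = dim // 2
    scale = math.log(10000.0) / (half_dim - 1)
    # Frequencies fall off geometrically from 1 down to 1/10000.
    freqs = [math.exp(-scale * i) for i in range(half_dim)]
    angles = [timestep * f for f in freqs]
    # First half: sines, second half: cosines (same order as torch.cat above).
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


emb_0 = sinusoidal_embedding(0)
emb_250 = sinusoidal_embedding(250)
print(len(emb_250), [round(v, 3) for v in emb_250[:4]])
```

Every entry stays between -1 and 1 regardless of how large the timestep is, and nearby timesteps get similar but distinct vectors, which is far easier for the network to work with than a raw integer like 250.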

The Block class is the fundamental building block of our U-Net. It contains convolutional layers, batch normalization, and an essential time-MLP layer that processes the timestep embedding. The forward method shows how the time embedding is added to the image features, so every block knows how noisy its input is.

The SimpleUnet class ties everything together. It defines the complete U-Net architecture, which consists of a series of downsampling blocks, followed by upsampling blocks. The skip connections are handled by a residual_inputs list, which saves the output of each downsampling block and concatenates it with the input of the matching upsampling block. This is what allows the U-Net to preserve fine-grained details during the denoising process.
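
It helps to trace the tensor shapes through the network once. The sketch below does this in plain Python using the standard output-size formulas for a strided convolution and a transposed convolution; the helper names are hypothetical, and the kernel, stride, and padding values (4, 2, 1) are the ones used by the transform layers in Block.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial size after a strided convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride, padding):
    """Spatial size after a transposed convolution."""
    return (size - 1) * stride - 2 * padding + kernel


down_channels = (64, 128, 256, 512, 1024)
up_channels = (1024, 512, 256, 128, 64)

channels, size = down_channels[0], 32  # after the initial 3x3 projection
shapes = [(channels, size)]
skips = []

for out_ch in down_channels[1:]:
    channels, size = out_ch, conv_out(size, kernel=4, stride=2, padding=1)
    shapes.append((channels, size))
    skips.append((channels, size))  # saved for the skip connection

for in_ch, out_ch in zip(up_channels[:-1], up_channels[1:]):
    skip_ch, skip_size = skips.pop()
    assert skip_size == size                  # concatenation needs equal height/width
    assert channels + skip_ch == 2 * in_ch    # matches Conv2d(2 * in_ch, ...) in Block
    channels, size = out_ch, conv_transpose_out(size, kernel=4, stride=2, padding=1)
    shapes.append((channels, size))

for ch, side in shapes:
    print(f"{ch:5d} channels, {side}x{side}")
```

The image shrinks from 32×32 down to 2×2 while the channel count grows to 1024, then climbs back to 32×32 with 64 channels, ready for the final 1×1 convolution.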

Step 3: The Reverse Process (Denoising)

This is the generation part! The reverse process utilizes our trained U-Net to transform pure noise into a generated image. This is often referred to as sampling:

@torch.no_grad()
def p_sample(model, x, t, t_index):
    """
    Performs one step of the reverse diffusion process (sampling).

    Args:
        model (nn.Module): The U-Net model.
        x (torch.Tensor): The current noisy image (x_t).
        t (torch.Tensor): The current timestep.
        t_index (int): The index of the current timestep.

    Returns:
        torch.Tensor: The de-noised image for the previous timestep (x_{t-1}).
    """
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2, line 4 of the DDPM paper (Ho et al., 2020): add scaled noise to the mean
        return model_mean + torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def p_sample_loop(model, shape):
    """
    The full sampling loop to generate an image from noise.

    Args:
        model (nn.Module): The U-Net model.
        shape (tuple): The shape of the image to generate (e.g., [batch_size, channels, H, W]).

    Returns:
        torch.Tensor: The final generated image.
    """
    img = torch.randn(shape, device=DEVICE)
    imgs = []

    for i in tqdm(reversed(range(0, TIMESTEPS)), desc="Sampling loop", total=TIMESTEPS):
        t = torch.full((shape[0],), i, device=DEVICE, dtype=torch.long)
        img = p_sample(model, img, t, i)
        # Optional: save intermediate steps
        # if i % 50 == 0:
        #     imgs.append(img.cpu())
    return img

Here, the p_sample function performs a single step of the reverse diffusion process, going from x_t to x_{t-1} (a slightly less noisy image). It uses our trained U-Net model to predict the noise that needs to be removed at that step. The function then uses the predicted noise to calculate the de-noised image for the previous timestep, gradually cleaning up the image.
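
A tiny worked example shows why this formula makes sense. The sketch below is a scalar, plain-Python version of the model_mean line in p_sample (the function name reverse_mean is hypothetical). It assumes a perfect model that predicts exactly the noise that was added, and looks at the very last denoising step (timestep index 0), where the cumulative product of alphas has only one factor.

```python
import math


def reverse_mean(x_t, predicted_noise, beta_t, alpha_bar_t):
    """Scalar version of the model_mean formula used in p_sample."""
    alpha_t = 1.0 - beta_t
    noise_scale = beta_t / math.sqrt(1.0 - alpha_bar_t)
    return (x_t - noise_scale * predicted_noise) / math.sqrt(alpha_t)


# At timestep index 0, alpha_bar equals alpha.
beta_0 = 0.0001
alpha_bar_0 = 1.0 - beta_0

x_0 = 0.7     # a clean pixel value in [-1, 1]
noise = -1.3  # the noise that was actually added

# Forward process: jump from x_0 to the noisy value at step 0.
x_t = math.sqrt(alpha_bar_0) * x_0 + math.sqrt(1.0 - alpha_bar_0) * noise

# Reverse step with a perfect noise prediction.
recovered = reverse_mean(x_t, predicted_noise=noise, beta_t=beta_0, alpha_bar_t=alpha_bar_0)
print(recovered)
```

With a perfect prediction, the last step gives back the clean pixel value up to floating-point rounding. A real model is never perfect, which is why sampling takes many small, partly random steps instead of one big jump.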

The p_sample_loop function is the complete sampling loop. It starts with a tensor of pure random noise (torch.randn) and iterates backwards from the last timestep index (TIMESTEPS - 1) down to zero. In each iteration, it calls p_sample to take a tiny step toward a clearer image. Ultimately, you are left with a brand-new, generated image.

Step 4: Training the Model

With all the pieces in place, it’s time to train our model. The goal of training is to teach our U-Net to predict the noise at any given timestep accurately:

def get_loss(model, x_start, t):
    """
    Calculates the loss for a given batch.
    """
    # 1. Generate random noise
    noise = torch.randn_like(x_start)

    # 2. Get the noised image at timestep t
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)

    # 3. Get the model's noise prediction
    predicted_noise = model(x_noisy, t)

    # 4. Calculate the loss between the actual noise and predicted noise
    loss = F.l1_loss(noise, predicted_noise) # L1 loss is common and works well

    return loss

def train():
    dataloader = get_data_loader()
    model = SimpleUnet().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(EPOCHS):
        # Use tqdm for a nice progress bar
        loop = tqdm(dataloader, leave=True)
        for batch_idx, (images, _) in enumerate(loop):
            images = images.to(DEVICE)

            # Sample a random timestep for each image in the batch.
            # Use the actual batch size: the last batch can be smaller than BATCH_SIZE.
            t = torch.randint(0, TIMESTEPS, (images.shape[0],), device=DEVICE).long()

            # Calculate loss
            loss = get_loss(model, images, t)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update progress bar
            loop.set_description(f"Epoch [{epoch+1}/{EPOCHS}]")
            loop.set_postfix(loss=loss.item())

        # --- After each epoch, generate and save a sample image ---
        print(f"Epoch {epoch+1} completed. Generating sample image...")
        num_images_to_sample = 16
        sample_shape = (num_images_to_sample, 1, IMG_SIZE, IMG_SIZE)

        # Generate the images
        generated_images = p_sample_loop(model, sample_shape)

        # Denormalize from [-1, 1] to [0, 1] for saving
        generated_images = (generated_images + 1) * 0.5

        # Save the image grid
        save_image(generated_images, os.path.join(OUTPUT_DIR, f"epoch_{epoch+1}_sample.png"), nrow=4)
        print(f"Sample image saved for epoch {epoch+1}.")

    # Save the final model
    torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "diffusion_mnist.pth"))
    print("Training finished and model saved.")
    
if __name__ == "__main__":
    print("Starting Diffusion Model Training...")
    print(f"Device: {DEVICE}")
    print(f"Timesteps: {TIMESTEPS}")
    print(f"Batch Size: {BATCH_SIZE}")

    # Run the training process
    train()

    print("All steps complete. You can find generated images and the saved model in the 'diffusion_outputs' directory.")

Running the script prints output that begins like this:

Starting Diffusion Model Training...
Device: cpu
Timesteps: 300
Batch Size: 128
100%|██████████| 9.91M/9.91M [00:00<00:00, 37.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.10MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 8.58MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.39MB/s]
Epoch [1/20]: 66%|██████▌ | 310/469 [2:29:46<1:14:55, 28.27s/it, loss=0.11]

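On a CPU, expect training to take a long time. In the log above, each batch takes about 28 seconds, which works out to roughly three and a half hours per epoch for 469 batches. A GPU will reduce this considerably.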

The script will create a data folder for the MNIST dataset and a diffusion_outputs folder. Inside diffusion_outputs, you will see sample images being saved after each epoch. Additionally, the final trained model, diffusion_mnist.pth, will be present at the end.

The get_loss function is our objective function. We start with a batch of clean images (x_start), add noise to each one at a randomly chosen timestep t to get x_noisy, and then feed x_noisy and t into our U-Net. The U-Net predicts the noise, and our loss is simply the difference between the noise we actually added and the noise our model predicted. We use L1 loss for this, which measures the mean absolute error.
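
To see what the loss actually measures, here is a plain-Python sketch with made-up numbers. The two helper functions are hypothetical stand-ins: l1_loss averages absolute differences, as F.l1_loss does with its default settings, and mse_loss is shown only for comparison as another common choice.

```python
def l1_loss(target, prediction):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(target, prediction)) / len(target)


def mse_loss(target, prediction):
    """Mean squared error, shown for comparison."""
    return sum((t - p) ** 2 for t, p in zip(target, prediction)) / len(target)


true_noise = [0.5, -1.0, 0.25, 2.0]       # the noise we added (illustrative values)
predicted_noise = [0.4, -0.7, 0.25, 1.0]  # what a model might predict

print("L1 :", l1_loss(true_noise, predicted_noise))
print("MSE:", mse_loss(true_noise, predicted_noise))
```

The single large miss (2.0 versus 1.0) accounts for most of either loss, but it weighs more heavily under MSE because the error is squared.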

The training loop iterates through the dataset for a specified number of epochs. In each iteration, it performs the following:

  1. Loads a batch of images.
  2. Samples a random timestep t for each image in the batch.
  3. Calculates the loss using get_loss.
  4. Performs backpropagation and updates the model’s weights using the optimizer.
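
One detail in step 2 is worth spelling out: there must be exactly one timestep per image in the current batch. MNIST's 60,000 training images do not divide evenly by 128, so the final batch of each epoch is smaller than the rest. The plain-Python sketch below illustrates this with hypothetical helper names; random.Random stands in for torch.randint.

```python
import random


def batch_sizes(num_images, batch_size):
    """Sizes of the batches a DataLoader yields when the last batch is kept."""
    full, remainder = divmod(num_images, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])


def sample_timesteps(current_batch_size, timesteps, rng):
    """One random timestep index per image; the upper bound is excluded."""
    return [rng.randrange(0, timesteps) for _ in range(current_batch_size)]


sizes = batch_sizes(60000, 128)
rng = random.Random(0)
last_t = sample_timesteps(sizes[-1], 300, rng)

print(len(sizes), sizes[-1], len(last_t))
```

There are 469 batches per epoch, and the last one holds only 96 images, so the number of sampled timesteps should follow the size of the batch at hand rather than a fixed constant.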

After each epoch, the code generates and saves a sample image using the p_sample_loop to show the model’s progress. This is the moment you see the model’s creativity come to life!

Final Words

Building a diffusion model from scratch is more than just a coding exercise. You’ve now seen how a seemingly complex task like generating an image can be broken down into a series of simple, predictable steps. The process of adding noise and then learning to remove it is an elegant and powerful concept that unlocks new creative possibilities every day.

I hope you liked this article on building a diffusion model from scratch. Feel free to ask questions in the comments section below. You can follow me on Instagram for many more resources.

Aman Kharwal

AI/ML Engineer | Published Author. My aim is to decode data science for the real world in the most simple words.
