Generative AI Model From Scratch with Python

A Machine Learning model makes predictions based on historical data. A Generative AI model, by contrast, creates new, original data by learning the patterns in that historical data. So, if you want to learn how to build a Generative AI model from scratch, this article is for you: I'll take you through the task of building a Generative AI model from scratch with Python.

How to Build a Generative AI Model?

One popular way to build a generative AI model from scratch is a dual-network setup consisting of a generator and a discriminator, unlike a typical machine learning model, which usually involves a single predictive model.

In generative AI, the generator creates new data samples from random noise, while the discriminator evaluates these samples against real data to classify them as real or fake. The two networks are trained in tandem through an adversarial process where the generator aims to improve its ability to produce realistic outputs, and the discriminator enhances its accuracy in distinguishing between genuine and generated data.

In this article, I’ll take you through building a Generative AI model using Generative Adversarial Networks (GANs). So, let’s understand what GANs are.

Introducing GANs

Generative Adversarial Networks (GANs) consist of two neural networks:

  1. Generator: Generates new data samples.
  2. Discriminator: Evaluates whether a given data sample is real (from the training data) or fake (generated by the generator).

The two networks are trained together in a zero-sum game: the generator tries to fool the discriminator, while the discriminator aims to accurately distinguish real from fake data.
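For readers who want the formal statement, the original GAN paper (Goodfellow et al., 2014) frames this game as a minimax objective:

min_G max_D V(D, G) = E_{x ~ p_data}[log D(x)] + E_{z ~ p_z}[log(1 − D(G(z)))]

where D(x) is the probability the discriminator assigns to x being real, and G(z) is the sample the generator produces from a noise vector z.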

A GAN consists of the following key components:

  • Noise Vector: A random input vector fed into the generator.
  • Generator: A neural network that transforms the noise vector into a data sample.
  • Discriminator: A neural network that classifies input data as real or fake.
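To make these pieces concrete, here is a schematic sketch of a single forward pass (pseudocode, not part of the final script; the real Keras implementation follows below):

noise = np.random.normal(0, 1, (1, 100))  # Noise Vector: 100 random numbers
fake_image = generator(noise)             # Generator: noise -> 28x28 image
p_real = discriminator(fake_image)        # Discriminator: probability the image is real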

Getting Started with Building a Generative AI Model

In this article, we will use the MNIST dataset to build a generative AI model from scratch, for three reasons:

  1. The dataset is large enough to train a generative model.
  2. It is easy to load, as it ships with Keras.
  3. Training on it is feasible with the computational power of Google Colab, which most beginners use as students.

So, we will train a Generative AI model to generate images. Let’s import the necessary Python libraries and the dataset to get started with the task:

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization, LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam

import ssl

# Work around SSL certificate verification errors that some environments
# raise when Keras downloads the MNIST file
ssl._create_default_https_context = ssl._create_unverified_context

# We only need the training images; the labels and the test split are discarded
(X_train, _), (_, _) = mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 3s 0us/step
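The dataset loads as 60,000 grayscale training images of shape 28×28. Before building the networks, it can help to preview a few samples to confirm everything loaded correctly (an optional check, not part of the model itself):

# Optional: preview the first five training digits
fig, axs = plt.subplots(1, 5, figsize=(10, 2))
for i in range(5):
    axs[i].imshow(X_train[i], cmap='gray')
    axs[i].axis('off')
plt.show()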

Now, we will build a generator network. The generator network transforms a random noise vector into a data sample. We’ll use a simple feed-forward network with several layers:

def build_generator():
    model = Sequential()
    # Project the 100-dimensional noise vector into a larger feature space
    model.add(Dense(256, input_dim=100))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    # 784 = 28 x 28 pixels; tanh keeps outputs in [-1, 1]
    model.add(Dense(784, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model

generator = build_generator()

Our generator starts with a dense (fully connected) layer that takes a 100-dimensional noise vector as input. The input then passes through several blocks, each comprising a dense layer followed by a LeakyReLU activation function to introduce non-linearity and a BatchNormalization layer to stabilize training and improve convergence. The network progressively increases the number of neurons, culminating in a dense layer with 784 neurons, corresponding to the flattened pixel values of a 28×28 image.

This final output is reshaped to match the original image dimensions, with a ‘tanh’ activation function to output values in the range [-1, 1], suitable for image data. This setup enables the generator to transform random noise into structured image data that mimics the real data distribution.
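As a quick sanity check, we can push one noise vector through the untrained generator and confirm the output shape and value range (the image itself will still look like noise at this point):

# Untrained generator: output is noise-like, but shape and range are already correct
noise = np.random.normal(0, 1, (1, 100))
sample = generator.predict(noise, verbose=0)
print(sample.shape)                 # (1, 28, 28, 1)
print(sample.min(), sample.max())   # values within [-1, 1] thanks to tanh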

Now, we will build the discriminator network. The discriminator network will classify input images as real or fake. It will be a binary classifier that outputs the probability of an image being real:

def build_discriminator():
    model = Sequential()
    # Flatten the 28x28x1 image into a 784-dimensional vector
    model.add(Flatten(input_shape=(28, 28, 1)))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    # Single sigmoid output: probability that the input image is real
    model.add(Dense(1, activation='sigmoid'))
    return model

discriminator = build_discriminator()
discriminator.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy', metrics=['accuracy'])

The discriminator network begins with a Flatten layer, which converts the 28×28 pixel images into one-dimensional arrays to prepare them for the fully connected layers. The input then passes through two dense layers, with 512 and 256 neurons respectively, each followed by a LeakyReLU activation function to introduce non-linearity. The final dense layer has a single neuron with a 'sigmoid' activation, which outputs a probability score indicating whether the input image is real or fake.

The model is then compiled with the Adam optimizer and binary crossentropy loss function to optimize the network to accurately distinguish real images from those generated by the generator.
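Before any training, the discriminator has no idea what a digit looks like, so its scores should hover around 0.5 for any input. A quick check (assuming X_train from the earlier load, scaled to match the generator's output range):

# Score one real digit with the untrained discriminator (expect roughly 0.5)
real_batch = (X_train[:1].astype(np.float32) - 127.5) / 127.5   # scale to [-1, 1]
real_batch = real_batch.reshape(1, 28, 28, 1)
print(discriminator.predict(real_batch, verbose=0))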

Compiling and Training the GAN to Generate Images

Now, we will combine the generator and discriminator to train our Generative AI model to generate images. During training, the discriminator will be trained to distinguish real from fake images, while the generator will be trained to produce images that fool the discriminator:

# Freeze the discriminator's weights inside the combined model so that
# only the generator is updated when the GAN is trained
discriminator.trainable = False

gan_input = Input(shape=(100,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)

gan = Model(gan_input, gan_output)
gan.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy')

def train_gan(epochs, batch_size=128):
    # Reload MNIST and scale pixel values from [0, 255] to [-1, 1]
    # so that real images match the range of the generator's tanh output
    (X_train, _), (_, _) = mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)
    
    # Labels: 1 for real images, 0 for generated (fake) images
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    
    for epoch in range(epochs):
        # Train the discriminator on a random batch of real images
        # and a batch of freshly generated fake images
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_images = X_train[idx]
        
        noise = np.random.normal(0, 1, (batch_size, 100))
        generated_images = generator.predict(noise, verbose=0)
        
        d_loss_real = discriminator.train_on_batch(real_images, real)
        d_loss_fake = discriminator.train_on_batch(generated_images, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        # Train the generator (through the combined model) to make the
        # discriminator classify its outputs as real
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan.train_on_batch(noise, real)
        
        if epoch % 100 == 0:
            print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss}]")
            save_images(epoch)

def save_images(epoch):
    # Generate a 5x5 grid of digits and save it to disk to track progress
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    generated_images = generator.predict(noise, verbose=0)
    
    # Rescale from [-1, 1] (tanh output) back to [0, 1] for display
    generated_images = 0.5 * generated_images + 0.5
    
    fig, axs = plt.subplots(r, c)
    count = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(generated_images[count, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            count += 1
    fig.savefig(f"gan_images_{epoch}.png")
    plt.close()

train_gan(epochs=10000, batch_size=64)

In the above code, we are defining and training a Generative Adversarial Network by combining a generator and a discriminator into a single model.

The gan_input represents the random noise fed into the generator, which produces a generated_image. This image is then passed to the discriminator, which outputs a probability (gan_output) indicating whether the image is real or fake.

The discriminator’s weights are set to non-trainable during this process to ensure that only the generator learns from the adversarial feedback. The GAN is trained using the function train_gan, where the discriminator first learns to distinguish between real images and fake images generated by the generator, and then the generator is updated to produce more convincing fake images.
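One subtlety worth noting: discriminator.trainable = False was set after the standalone discriminator was compiled, so discriminator.train_on_batch still updates its weights, while inside the combined gan model its layers stay frozen. You can verify which weights the combined model will update (a quick check, relying on how Keras captures the trainable flag at compile time):

# Only the generator's weights should be trainable inside the combined model
print(len(gan.trainable_weights))            # generator weight tensors only
print(len(discriminator.trainable_weights))  # empty after the flag was set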

The loss functions guide this adversarial process, where the generator aims to minimize the discriminator’s ability to detect fakes, which results in progressively more realistic generated images. The save_images function periodically saves these generated images to visualize the training progress.

The model generated several images. Look at these three images in the output below:

[Image: a sample output from the early stages (Epoch 0)]

[Image: a sample output from the intermediate stages (Epochs 100, 200, …, 9000)]

[Image: a sample output from the later stages (Epochs 9100 to 9900)]

The first image shows the early outputs. Initially, the images appear as random noise without any discernible patterns, which indicates that the generator had not yet learned to produce meaningful outputs at this stage of training.

The second image shows outputs from the intermediate stages. As training progresses, the generator starts producing outputs that begin to resemble the structure of handwritten digits. Although some images still appear noisy or indistinct, there is a noticeable shift towards more defined shapes and features. During these stages, the generator and discriminator are in a competitive phase where both networks are improving: the generator tries to create more realistic images, while the discriminator enhances its ability to distinguish between real and generated samples.

The third image shows outputs from the later stages of training. The images show a significant improvement, with many outputs clearly resembling real MNIST digits. The details of the digits are more defined and the shapes more accurate, which reflects the generator's increased ability to capture the distribution of the training data.
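Once training finishes, generating a brand-new digit is just one more forward pass through the generator, mirroring the rescaling used in save_images:

# Generate a single new digit from fresh random noise
noise = np.random.normal(0, 1, (1, 100))
digit = generator.predict(noise, verbose=0)
digit = 0.5 * digit + 0.5                     # rescale from [-1, 1] to [0, 1]
plt.imshow(digit[0, :, :, 0], cmap='gray')
plt.axis('off')
plt.show()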

So, this is how we can build a Generative AI model from scratch using Python with GANs.

Summary

To summarize: in a GAN, the generator creates new data samples from random noise, while the discriminator evaluates those samples against real data and classifies them as real or fake. Trained together through this adversarial process, the generator steadily gets better at producing realistic outputs, and the discriminator gets better at distinguishing genuine data from generated data.

I hope you liked this article on how to build a Generative AI model from scratch using Python. Feel free to ask valuable questions in the comments section below. You can follow me on Instagram for many more resources.
