Training GANs on MNIST

17th Dec 2024

Generative Adversarial Networks

Following on from my previous article here, today we will be exploring generative adversarial networks (GANs). Unlike discriminative models, which perform classification or regression tasks for us, generative models synthesize new data that resembles the training data. GANs are a simple and interesting type of model, and in this article we will use them to generate new digits based on the MNIST data set.

GAN Architecture

As you can see, a GAN is actually made up of two models:

Generative Adversarial Networks

The generator G starts off with noise z and uses it to generate a sample G(z). This sample is passed into another model, the discriminator D. The discriminator is trained on both real and generated samples, and its job is to decide whether what it's given is real or fake.

The discriminator's verdict is then fed back into the generator as a training signal, which results in the generator iteratively producing more and more "realistic" looking images. When the discriminator is no longer able to tell the difference between the generated and real images, we know we're done.
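For reference, this tug-of-war is usually written as the minimax objective from the original GAN formulation. The symbols p_data (the distribution of real images) and p_z (the noise distribution) are notation introduced just for this block:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

At the ideal end point the discriminator outputs D(x) = 1/2 for every input, which is the formal version of "no longer able to tell the difference".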

Like in my article here, I will start with an MLP and then add convolutional layers, which will greatly improve our results, as you will soon see.

I start with imports and params:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import os
import matplotlib.pyplot as plt

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyperparameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
learning_rate = 0.0002

Then I load my data, same as last time:

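# Scale pixels from [0, 1] to [-1, 1] so real images share the range
# of the generator's tanh output
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])

dataset = dsets.MNIST(root='data/',
                      train=True,
                      transform=transform,
                      download=True)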

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

Then I define my Generator model:

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.relu(out)
        out = self.fc3(out)
        out = self.tanh(out)
        return out

The generator G will end up taking in noise z and turning it into images.

We create our generator, and we define our discriminator:

# Create generator
G = Generator(latent_size, hidden_size, image_size).to(device)

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.relu(out)
        out = self.fc3(out)
        out = self.sigmoid(out)
        return out

# Create discriminator
D = Discriminator(image_size, hidden_size, 1).to(device)
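Before training, it can help to push a few noise vectors through both networks and confirm the shapes and value ranges line up. The sketch below is only a sanity check: `G_check` and `D_check` are hypothetical compact stand-ins with the same layer sizes as the classes above, so the snippet runs on its own.

```python
import torch
import torch.nn as nn

latent_size, hidden_size, image_size = 64, 256, 784

# Compact stand-ins with the same layer sizes as the Generator and Discriminator classes
G_check = nn.Sequential(
    nn.Linear(latent_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, image_size), nn.Tanh(),
)
D_check = nn.Sequential(
    nn.Linear(image_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, 1), nn.Sigmoid(),
)

with torch.no_grad():
    z = torch.randn(8, latent_size)  # 8 noise vectors
    fake = G_check(z)                # 8 flattened 28x28 images, values in [-1, 1]
    score = D_check(fake)            # 8 probabilities that each image is "real"

print(fake.shape, score.shape)
```

Each noise vector becomes one 784-value image, and the discriminator reduces each image to a single probability.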

Now we define our loss function, our optimisers, and an output directory:

# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(G.parameters(), lr=learning_rate)
# Create a directory to save generated images
os.makedirs('gan_images', exist_ok=True)
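To see what the binary cross entropy loss is doing with the 1 ("real") and 0 ("fake") labels, here is a small worked example. The discriminator outputs in `d_out` are made-up numbers for illustration, not values from a trained model.

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCELoss()

# Hypothetical discriminator outputs for three images
d_out = torch.tensor([[0.9], [0.6], [0.2]])

# With target 1 the loss is the mean of -log(p)
loss_as_real = criterion(d_out, torch.ones(3, 1))
# With target 0 the loss is the mean of -log(1 - p)
loss_as_fake = criterion(d_out, torch.zeros(3, 1))

# The same quantities computed by hand
expected_real = -sum(math.log(p) for p in (0.9, 0.6, 0.2)) / 3
expected_fake = -sum(math.log(1 - p) for p in (0.9, 0.6, 0.2)) / 3

print(loss_as_real.item(), expected_real)
print(loss_as_fake.item(), expected_fake)
```

Pairing generated images with the "real" label therefore penalises the generator whenever the discriminator scores its images low.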

Finally, we can train it by doing the following:

# Training the GAN
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Flatten each 28x28 image in the batch into a 784-dimensional vector
        images = images.reshape(images.size(0), -1).to(device)

        # Create labels for real and fake images
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Train the discriminator
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        # detach() keeps this discriminator update from backpropagating into G
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Train the generator
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)

        g_loss = criterion(outputs, real_labels)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 200 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}')

    # Save a sample of generated images after the first epoch, then every 20 epochs
    if (epoch+1) == 1 or (epoch+1) % 20 == 0:
        with torch.no_grad():
            fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
            save_image(fake_images, os.path.join('gan_images', f'fake_images-{epoch+1}.png'))

As you can see, we have trained our GAN on MNIST, with the generator starting from random noise. Let's check out the results:

GAN results 200 epochs

They certainly look like handwritten digits, but they suck! There's lots of noise in the results, and some don't contain a digit at all!

Let's improve our model by adding convolutions:

class Generator(nn.Module):
    def __init__(self, input_size):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(input_size, 256, 7, 1, 0, bias=False),  # Output: (256, 7, 7)
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),  # Output: (128, 14, 14)
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),  # Output: (64, 28, 28)
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 3, 1, 1, bias=False),  # Output: (1, 28, 28)
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x.view(x.size(0), -1, 1, 1))
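The shape comments in the generator can be checked by hand. For a transposed convolution with no dilation and no output padding, the output side length is (in - 1) * stride - 2 * padding + kernel. The helper `conv_transpose_out` below is a hypothetical name used only for this check:

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output side length of ConvTranspose2d (dilation 1, no output_padding)."""
    return (size - 1) * stride - 2 * padding + kernel

size = 1  # the noise vector is viewed as a 1x1 "image" with latent_size channels
sizes = [size]
# (kernel, stride, padding) for the four ConvTranspose2d layers above
for kernel, stride, padding in [(7, 1, 0), (4, 2, 1), (4, 2, 1), (3, 1, 1)]:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [1, 7, 14, 28, 28]
```

So the generator grows a single "pixel" of noise into a 28x28 image, doubling the resolution twice along the way.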

The discriminator is defined as:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(1, 64, 3, 2, 1, bias=False),  # Output: (64, 14, 14)
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2),  # Output: (64, 7, 7)
            nn.Conv2d(64, 128, 3, 2, 1, bias=False),  # Output: (128, 4, 4)
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.MaxPool2d(2),  # Output: (128, 2, 2)
            nn.Conv2d(128, 256, 3, 2, 1, bias=False),  # Output: (256, 1, 1)
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, 1, 1, 0, bias=False),  # Output: (1, 1, 1)
            nn.Flatten(),  # Flatten the tensor before the linear layer
            nn.Linear(1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)
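The same kind of arithmetic confirms the discriminator's shape comments. For a convolution or max pool without dilation, the output side length is floor((in + 2 * padding - kernel) / stride) + 1. The helper name `conv_out` is again just for illustration:

```python
def conv_out(size, kernel, stride, padding):
    """Output side length of Conv2d or MaxPool2d (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding) for each spatial layer in the discriminator
layers = [
    ("conv 3/2/1", 3, 2, 1),
    ("maxpool 2", 2, 2, 0),
    ("conv 3/2/1", 3, 2, 1),
    ("maxpool 2", 2, 2, 0),
    ("conv 3/2/1", 3, 2, 1),
    ("conv 1/1/0", 1, 1, 0),
]

size = 28
trace = [size]
for _, kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    trace.append(size)

print(trace)  # [28, 14, 7, 4, 2, 1, 1]
```

By the last convolution each image has been reduced to a single number, so the final `nn.Linear(1, 1)` only applies a learned scale and shift before the sigmoid.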

We hope, first of all, that the discriminator will be more accurate at deciding whether an image is real or fake, since it's a "smarter" model (with convolutional layers). We also hope that by adding convolutional layers, the generator will learn the underlying properties of the images that we're trying to have it generate.
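Re-training with these models presumably needs two small changes to the earlier loop that the text does not spell out: the models are constructed with different arguments, and real images must keep their (N, 1, 28, 28) shape instead of being flattened to 784 values. Here is a minimal sketch of one loss computation under those assumptions. It uses compact `nn.Sequential` copies of the two networks and a random stand-in batch so that it runs on its own.

```python
import torch
import torch.nn as nn

latent_size, batch_size = 64, 100

# Compact copy of the convolutional generator; Unflatten plays the role of x.view(...)
G = nn.Sequential(
    nn.Unflatten(1, (latent_size, 1, 1)),
    nn.ConvTranspose2d(latent_size, 256, 7, 1, 0, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
    nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
    nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
    nn.ConvTranspose2d(64, 1, 3, 1, 1, bias=False), nn.Tanh(),
)
# Compact copy of the convolutional discriminator
D = nn.Sequential(
    nn.Conv2d(1, 64, 3, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, 2, 1, bias=False), nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True), nn.MaxPool2d(2),
    nn.Conv2d(128, 256, 3, 2, 1, bias=False), nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(256, 1, 1, 1, 0, bias=False), nn.Flatten(), nn.Linear(1, 1), nn.Sigmoid(),
)

criterion = nn.BCELoss()

# Stand-in for one MNIST batch: note there is no reshape to (N, 784) here
images = torch.rand(batch_size, 1, 28, 28) * 2 - 1
n = images.size(0)
real_labels = torch.ones(n, 1)
fake_labels = torch.zeros(n, 1)

z = torch.randn(n, latent_size)
fake_images = G(z)  # already shaped (N, 1, 28, 28)

d_loss = criterion(D(images), real_labels) + criterion(D(fake_images.detach()), fake_labels)
g_loss = criterion(D(fake_images), real_labels)
print(fake_images.shape, d_loss.item(), g_loss.item())
```

The optimiser steps and image saving can stay as they were in the MLP version.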

Now, when we re-train we get:

GAN v2 results 200 epochs

That looks awesome! We wanted a GAN that generates good-looking images resembling MNIST digits, and we have exactly that! The amazing thing here is that these are completely new images that don't exist anywhere in our dataset.

Well, thank you for reading, and I hope you found this to be a gentle and palatable introduction to generative ML. See you next time!