Project Analysis

GAN Model for MNIST Digit Generation

About Project

Generating realistic handwritten digit images requires a model capable of learning complex data distributions from limited training samples. Traditional image generation methods often lack diversity and tend to overfit the dataset.

Project Challenge

The key challenge is to train a GAN (Generative Adversarial Network) that can:

  • Generate sharp and visually diverse handwritten digits
  • Maintain training stability throughout epochs
  • Avoid common pitfalls such as mode collapse and vanishing gradients


Objective

To implement and train a Generative Adversarial Network (GAN) on the MNIST dataset that learns to generate synthetic 28×28 grayscale handwritten digit images.

Project Goals
  • Monitor training progression using generated image samples and loss plots
  • Perform hyperparameter tuning to identify optimal training configurations
  • Justify all architectural and design decisions for both generator and discriminator
  • Explore advanced GAN variants such as:
    • DCGAN – Deep Convolutional GAN
    • WGAN – Wasserstein GAN for better stability
    • StyleGAN – For high-quality and stylized image synthesis

Proposed Solution

Project Components
  • Generator: Upsamples random noise into 28×28 grayscale images using Dense and Conv2DTranspose layers for feature expansion and image shaping.
  • Discriminator: Classifies real vs. fake images using Conv2D layers with Dropout for regularization and overfitting control.
  • Custom GAN Class: Encapsulates the full training loop with a custom train_step() method for adversarial updates of both generator and discriminator.
  • LossPlotCallback: Custom Keras callback to visualize and save generated image samples and loss trends after each epoch, aiding in training diagnostics.
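
The custom GAN class and its train_step() method can be sketched as follows. This is a minimal illustration assuming standard Keras subclassing, not the project's exact implementation; the 0.9 real-label smoothing follows the setting reported under Challenges Faced.

```python
import tensorflow as tf
from tensorflow import keras


class GAN(keras.Model):
    """Minimal sketch of the custom GAN class with adversarial train_step()."""

    def __init__(self, generator, discriminator, latent_dim=100):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim
        self.loss_fn = keras.losses.BinaryCrossentropy()

    def compile(self, g_optimizer, d_optimizer):
        super().compile()
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer

    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]

        # --- Discriminator update: real labels smoothed to 0.9, fakes to 0 ---
        noise = tf.random.normal((batch_size, self.latent_dim))
        fake_images = self.generator(noise)
        with tf.GradientTape() as tape:
            real_pred = self.discriminator(real_images)
            fake_pred = self.discriminator(fake_images)
            d_loss = (self.loss_fn(0.9 * tf.ones_like(real_pred), real_pred)
                      + self.loss_fn(tf.zeros_like(fake_pred), fake_pred))
        grads = tape.gradient(d_loss, self.discriminator.trainable_variables)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_variables))

        # --- Generator update: try to make the discriminator output "real" ---
        noise = tf.random.normal((batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            fake_pred = self.discriminator(self.generator(noise))
            g_loss = self.loss_fn(tf.ones_like(fake_pred), fake_pred)
        grads = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_variables))

        return {"d_loss": d_loss, "g_loss": g_loss}
```

Subclassing keras.Model this way lets the standard fit() loop drive adversarial training, so callbacks such as the LossPlotCallback plug in unchanged.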

Technologies Used

Tools & Technologies
Frameworks & Libraries
  • TensorFlow / Keras – For building and training GAN models
  • Matplotlib / PIL – For saving generated images and plotting loss curves
Dataset
  • MNIST – Grayscale handwritten digit dataset (28×28)
Training Configuration
  • Adam Optimizer – With learning rate and β₁ tuning for stability
  • Binary Cross-Entropy – Used as the loss function for both generator and discriminator
Development Environment
  • Jupyter Notebook / Python – For implementation and experimentation
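
A minimal sketch of this training configuration; the specific learning-rate and β₁ values used here are the best-performing ones reported in the hyperparameter-tuning results later in this document.

```python
from tensorflow import keras

# Separate Adam optimizers for generator and discriminator,
# tuned for training stability (values from the reported best configuration).
g_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.4)
d_optimizer = keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.4)

# Binary cross-entropy serves as the loss for both networks.
loss_fn = keras.losses.BinaryCrossentropy()
```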

Challenges Faced

Training Challenges & Observations
  • Generator instability observed at higher epochs, requiring monitoring and early stopping
  • Mode collapse risk was mitigated using dropout in the discriminator and label smoothing (0.9 for real labels)
  • Spiking generator loss indicated increasing difficulty in fooling a strong discriminator
  • Small image resolution (28×28) limited the expressive power of the generator
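
The label-smoothing mitigation can be illustrated with a small self-contained example (the prediction values below are made up for demonstration): real images get target 0.9 instead of 1.0, which keeps the binary cross-entropy loss, and hence the gradient signal, non-trivial even when the discriminator is confidently correct.

```python
import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy()

# Confident discriminator outputs on three real images (illustrative values).
preds = tf.constant([[0.95], [0.90], [0.99]])

hard_loss = bce(tf.ones((3, 1)), preds)          # targets = 1.0
smooth_loss = bce(0.9 * tf.ones((3, 1)), preds)  # targets = 0.9 (smoothed)
```

With smoothed targets the loss stays strictly larger near confident predictions, so the discriminator is discouraged from saturating.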

Methodology

Generator Architecture
  • Input: 100-dimensional latent vector
  • Dense layer: Output shape 7×7×256
  • Conv2DTranspose layers:
    • 128 filters → 7×7
    • 64 filters → 14×14
    • 1 filter → 28×28 (final output)
  • Activations: LeakyReLU for intermediate layers, Tanh for output
  • Output: 28×28 grayscale image
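
A sketch of this generator in Keras. The kernel sizes, strides, and LeakyReLU slope are assumptions, but the layer progression (100 → 7×7×256 → 7×7 → 14×14 → 28×28) follows the list above.

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_generator(latent_dim=100):
    return keras.Sequential([
        layers.Input(shape=(latent_dim,)),
        layers.Dense(7 * 7 * 256),
        layers.LeakyReLU(0.2),
        layers.Reshape((7, 7, 256)),
        layers.Conv2DTranspose(128, 4, strides=1, padding="same"),  # -> 7x7
        layers.LeakyReLU(0.2),
        layers.Conv2DTranspose(64, 4, strides=2, padding="same"),   # -> 14x14
        layers.LeakyReLU(0.2),
        layers.Conv2DTranspose(1, 4, strides=2, padding="same",
                               activation="tanh"),                  # -> 28x28
    ])
```

The tanh output implies training images are rescaled to [-1, 1] before being fed to the discriminator.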
Discriminator Architecture
  • Input: 28×28×1 image
  • Conv2D layers:
    • 64 filters + Dropout
    • 128 filters + Dropout
  • Output: Flatten → Dense → Sigmoid (probability real/fake)
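
A matching discriminator sketch; the strides, dropout rate, and LeakyReLU activations are assumptions, since the source lists only the filter counts, Dropout, and the Flatten → Dense → Sigmoid head.

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_discriminator():
    return keras.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, 4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Dropout(0.3),
        layers.Conv2D(128, 4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1, activation="sigmoid"),  # probability image is real
    ])
```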
GAN Training Loop
  • For each batch:
    • Sample a batch of noise vectors from the latent space and generate fake images
    • Train the discriminator on real images (labels smoothed to 0.9) and generated images (labels 0)
    • Sample fresh noise and train the generator so the discriminator classifies its outputs as real
    • Log both losses so the LossPlotCallback can plot them after each epoch

Result / Outcome

🖼️ Visual Output
  • Epoch 1: Pure random noise
  • Epoch 10: Digit-like shapes begin to emerge
  • Epoch 30: Recognizable, sharp handwritten digits

📉 Loss Behavior

  Epoch   Discriminator Loss (↓)   Generator Loss (↑)
  1       0.45                     0.44
  10      0.65                     0.91
  30      0.64                     0.95

  • Discriminator loss remained relatively stable across epochs
  • Generator loss increased: expected behavior as it learns to fool the discriminator

🔁 Hyperparameter Tuning
  • Tested configurations:
    • Learning rates: 0.0002, 0.0001
    • β₁ (Adam): 0.5, 0.4
    • Latent dimensions: 100, 128
  • Best configuration:
    • Learning rate = 0.0002
    • β₁ = 0.4
    • Latent dimension = 100
  • Final loss metrics: Generator = 0.80, Discriminator = 0.63
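
The tested values above amount to eight configurations in total. One way to enumerate the grid (the actual search procedure is not shown in this document, and the train-and-evaluate step is omitted):

```python
from itertools import product

learning_rates = [2e-4, 1e-4]
betas = [0.5, 0.4]          # beta_1 values for Adam
latent_dims = [100, 128]

# Cartesian product of the tested values: 2 x 2 x 2 = 8 configurations.
configs = list(product(learning_rates, betas, latent_dims))

best = (2e-4, 0.4, 100)  # best configuration reported above
```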
