Project Analysis
GAN Model for MNIST Digit Generation
Problem Statement
Generating realistic handwritten digit images requires a model capable of learning complex data distributions from limited training samples. Traditional image generation methods often lack diversity and tend to overfit the dataset.
The key challenge is to train a GAN (Generative Adversarial Network) that can:
- Generate sharp and visually diverse handwritten digits
- Maintain training stability throughout epochs
- Avoid common pitfalls such as mode collapse and vanishing gradients
Objective
To implement and train a Generative Adversarial Network (GAN) on the MNIST dataset that learns to generate synthetic 28×28 grayscale handwritten digit images. Supporting objectives:
- Monitor training progression using generated image samples and loss plots
- Perform hyperparameter tuning to identify optimal training configurations
- Justify all architectural and design decisions for both generator and discriminator
- Explore advanced GAN variants such as:
  - DCGAN – Deep Convolutional GAN
  - WGAN – Wasserstein GAN for better stability
  - StyleGAN – For high-quality and stylized image synthesis
Proposed Solution
- Generator: Upsamples random noise into 28×28 grayscale images using Dense and Conv2DTranspose layers for feature expansion and image shaping.
- Discriminator: Classifies real vs. fake images using Conv2D layers with Dropout for regularization and overfitting control.
- Custom GAN Class: Encapsulates the full training loop with a custom train_step() method for adversarial updates of both generator and discriminator.
- LossPlotCallback: Custom Keras callback to visualize and save generated image samples and loss trends after each epoch, aiding in training diagnostics. (A minimal callback sketch follows this list.)
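The callback itself is not reproduced in this report, so the following is only a sketch of what LossPlotCallback might look like. The logs keys (d_loss, g_loss), the generator attribute on the model, and the output file names are all assumptions.

```python
import matplotlib.pyplot as plt
import tensorflow as tf

class LossPlotCallback(tf.keras.callbacks.Callback):
    """Sketch: save a sample grid and loss curves after each epoch."""

    def __init__(self, latent_dim=100, num_samples=16):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_samples = num_samples
        self.d_losses, self.g_losses = [], []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.d_losses.append(logs.get("d_loss"))
        self.g_losses.append(logs.get("g_loss"))

        # Assumes the GAN model exposes its generator (as in the custom
        # GAN class sketched under Methodology below).
        noise = tf.random.normal((self.num_samples, self.latent_dim))
        images = self.model.generator(noise, training=False)
        images = (images + 1.0) / 2.0  # map tanh output [-1, 1] to [0, 1]

        # Save a 4x4 grid of generated digits for visual inspection.
        fig, axes = plt.subplots(4, 4, figsize=(4, 4))
        for img, ax in zip(images, axes.flat):
            ax.imshow(img[..., 0], cmap="gray")
            ax.axis("off")
        fig.savefig(f"samples_epoch_{epoch + 1:03d}.png")  # file name is an assumption
        plt.close(fig)

        # Save the running loss curves for training diagnostics.
        fig, ax = plt.subplots()
        ax.plot(self.d_losses, label="discriminator")
        ax.plot(self.g_losses, label="generator")
        ax.set_xlabel("epoch")
        ax.set_ylabel("loss")
        ax.legend()
        fig.savefig("loss_curves.png")
        plt.close(fig)
```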
Technologies Used
- TensorFlow / Keras – For building and training GAN models
- Matplotlib / PIL – For saving generated images and plotting loss curves
- MNIST – Grayscale handwritten digit dataset (28×28)
- Adam Optimizer – With learning rate and β₁ tuning for stability
- Binary Cross-Entropy – Used as the loss function for both generator and discriminator (a sample setup follows this list)
- Jupyter Notebook / Python – For implementation and experimentation
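For concreteness, the optimizer and loss setup implied by this list might look like the snippet below. The values mirror the best configuration reported under Methodology; the variable names are mine.

```python
import tensorflow as tf

# Separate Adam optimizers for the two networks, using the best
# configuration reported later (lr = 0.0002, β₁ = 0.4).
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.4)

# Binary cross-entropy drives both adversarial losses.
bce = tf.keras.losses.BinaryCrossentropy()
```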
Challenges Faced
- Generator instability observed at higher epochs, requiring monitoring and early stopping
- Mode collapse risk was mitigated using dropout in the discriminator and label smoothing (0.9 targets for real labels; see the snippet after this list)
- Spiking generator loss indicated increasing difficulty in fooling a strong discriminator
- Small image resolution (28×28) limited the expressive power of the generator
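As a concrete illustration of the label-smoothing mitigation, the real-image targets can be set to 0.9 instead of 1.0 when training the discriminator. This sketch assumes a batch size of 256.

```python
import tensorflow as tf

batch_size = 256  # assumed for illustration

# One-sided label smoothing: 0.9 targets keep the discriminator from
# becoming overconfident on real images, reducing mode-collapse risk.
real_labels = 0.9 * tf.ones((batch_size, 1))
fake_labels = tf.zeros((batch_size, 1))
```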
Methodology
Generator Architecture
- Input: 100-dimensional latent vector
- Dense layer: Projects the noise to 7×7×256, then reshaped into a feature map
- Conv2DTranspose layers:
  - 128 filters → 7×7
  - 64 filters → 14×14
  - 1 filter → 28×28 (final output)
- Activations: LeakyReLU for intermediate layers, Tanh for output
- Output: 28×28 grayscale image (a Keras sketch follows this list)
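Assembled in Keras, the generator described above might look like this sketch. Kernel sizes, strides, and bias settings are assumptions chosen to reproduce the stated 7×7 → 14×14 → 28×28 shape progression.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_generator(latent_dim=100):
    """Sketch: 100-d noise -> 28x28x1 grayscale image in [-1, 1]."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(latent_dim,)),
        layers.Dense(7 * 7 * 256, use_bias=False),
        layers.LeakyReLU(),
        layers.Reshape((7, 7, 256)),
        # 7x7 kept by stride 1, then two stride-2 upsamplings reach 28x28.
        layers.Conv2DTranspose(128, kernel_size=5, strides=1, padding="same", use_bias=False),
        layers.LeakyReLU(),
        layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding="same", use_bias=False),
        layers.LeakyReLU(),
        layers.Conv2DTranspose(1, kernel_size=5, strides=2, padding="same", activation="tanh"),
    ])
```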
Discriminator Architecture
- Input: 28×28×1 image
- Conv2D layers:
  - 64 filters + Dropout
  - 128 filters + Dropout
- Output: Flatten → Dense → Sigmoid (probability of real vs. fake; a Keras sketch follows this list)
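A matching Keras sketch of the discriminator. The kernel sizes, strides, and 0.3 dropout rate are assumptions, since the report only specifies filter counts and the use of Dropout.

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_discriminator():
    """Sketch: 28x28x1 image -> probability that the image is real."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, kernel_size=5, strides=2, padding="same"),
        layers.LeakyReLU(),
        layers.Dropout(0.3),  # dropout rate is an assumption
        layers.Conv2D(128, kernel_size=5, strides=2, padding="same"),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1, activation="sigmoid"),
    ])
```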
Training Loop
- For each batch: sample a latent noise vector, update the discriminator on real and generated images, then update the generator to better fool the discriminator (a train_step() sketch follows)
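A minimal version of this per-batch adversarial update, written as the custom train_step() mentioned in the Proposed Solution. The 0.9 real-label smoothing matches the Challenges section; everything else is a sketch rather than the report's exact code.

```python
import tensorflow as tf

class GAN(tf.keras.Model):
    """Sketch of the custom GAN class with an adversarial train_step()."""

    def __init__(self, generator, discriminator, latent_dim=100):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim

    def compile(self, g_optimizer, d_optimizer, loss_fn):
        super().compile()
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        batch_size = tf.shape(real_images)[0]
        noise = tf.random.normal((batch_size, self.latent_dim))

        # 1) Discriminator update: real images (smoothed 0.9 targets)
        #    vs. freshly generated fakes (0.0 targets).
        with tf.GradientTape() as tape:
            fake_images = self.generator(noise, training=True)
            real_preds = self.discriminator(real_images, training=True)
            fake_preds = self.discriminator(fake_images, training=True)
            d_loss = (self.loss_fn(0.9 * tf.ones_like(real_preds), real_preds)
                      + self.loss_fn(tf.zeros_like(fake_preds), fake_preds))
        grads = tape.gradient(d_loss, self.discriminator.trainable_variables)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_variables))

        # 2) Generator update: push the discriminator to label fakes as real.
        with tf.GradientTape() as tape:
            fake_preds = self.discriminator(self.generator(noise, training=True), training=True)
            g_loss = self.loss_fn(tf.ones_like(fake_preds), fake_preds)
        grads = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_variables))

        return {"d_loss": d_loss, "g_loss": g_loss}
```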
Training Progression
- Epoch 1: Pure random noise
- Epoch 10: Digit-like shapes begin to emerge
- Epoch 30: Recognizable, sharp handwritten digits
- Discriminator loss remained relatively stable across epochs
- Generator loss increased – expected behavior, since fooling a steadily improving discriminator becomes harder over time
Hyperparameter Tuning
- Tested configurations (swept as sketched after this list):
  - Learning rates: 0.0002, 0.0001
  - β₁ (Adam): 0.5, 0.4
  - Latent dimensions: 100, 128
- Best configuration:
  - Learning rate = 0.0002
  - β₁ = 0.4
  - Latent dimension = 100
- Final loss metrics: Generator = 0.80, Discriminator = 0.63
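One plausible way to run the sweep above, reusing the GAN, build_generator, build_discriminator, and LossPlotCallback sketches from earlier sections. The data pipeline and batch size are assumptions.

```python
import itertools
import tensorflow as tf

# Assumed pipeline: MNIST scaled to the generator's tanh range [-1, 1].
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = (x_train.astype("float32") - 127.5) / 127.5
dataset = tf.data.Dataset.from_tensor_slices(x_train[..., None]).shuffle(60000).batch(256)

# Grid sweep over the tested configurations.
for lr, beta1, latent_dim in itertools.product([2e-4, 1e-4], [0.5, 0.4], [100, 128]):
    gan = GAN(build_generator(latent_dim), build_discriminator(), latent_dim)
    gan.compile(
        g_optimizer=tf.keras.optimizers.Adam(lr, beta_1=beta1),
        d_optimizer=tf.keras.optimizers.Adam(lr, beta_1=beta1),
        loss_fn=tf.keras.losses.BinaryCrossentropy(),
    )
    gan.fit(dataset, epochs=30, callbacks=[LossPlotCallback(latent_dim)])
```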
Result / Outcome
| Epoch | Discriminator Loss | Generator Loss |
|---|---|---|
| 1 | 0.45 | 0.44 |
| 10 | 0.65 | 0.91 |
| 30 | 0.64 | 0.95 |