Foreword

In the previous chapter, we used the MNIST dataset to train a model capable of classifying handwritten digits. If the dataset is too small for a model to converge well, the most straightforward solution is to collect more data, but collecting and labeling data is time-consuming. Generative Adversarial Networks (GANs) have recently become a popular way to generate new data efficiently: given existing training data, a GAN can generate additional images that look remarkably close to the real ones.

Preparation

Create a GAN.py file and start by importing the necessary Python packages. The matplotlib package will be used later to save the generated images, and os to create the output directory for them.

import os
import numpy as np
import paddle
import paddle.fluid as fluid
import matplotlib.pyplot as plt

Define the Network

A GAN consists of a generator and a discriminator. The generator aims to produce images that fool the discriminator, while the discriminator learns to distinguish real images from generated ones. As training progresses, the discriminator becomes more accurate, and the generator produces increasingly realistic images. The generator here uses two fully connected layers with Batch Normalization (BN) followed by two transposed convolutions. The final transposed convolution has a single filter because the output is a single-channel grayscale digit image.
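
The spatial sizes in the generator go 7 → 14 → 28, which follows from the standard transposed-convolution output-size formula; as we understand conv2d_transpose's behavior, when output_size is passed it selects that size from the valid range [out, out + stride - 1]. A quick sketch of the arithmetic:

def deconv_out(in_size, stride=2, padding=2, filter_size=5, dilation=1):
    # Standard transposed-convolution output-size formula
    return (in_size - 1) * stride - 2 * padding + (filter_size - 1) * dilation + 1

print(deconv_out(7))   # 13 -> output_size=[14, 14] picks 14 from [13, 14]
print(deconv_out(14))  # 27 -> output_size=[28, 28] picks 28 from [27, 28]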

# Define the Generator
def Generator(y, name="G"):
    def deconv(x, num_filters, filter_size=5, stride=2, dilation=1, padding=2, output_size=None, act=None):
        return fluid.layers.conv2d_transpose(input=x,
                                             num_filters=num_filters,
                                             output_size=output_size,
                                             filter_size=filter_size,
                                             stride=stride,
                                             dilation=dilation,
                                             padding=padding,
                                             act=act)
    with fluid.unique_name.guard(name + "/"):
        # First fully connected and BN layer
        y = fluid.layers.fc(y, size=2048)
        y = fluid.layers.batch_norm(y)
        # Second fully connected and BN layer
        y = fluid.layers.fc(y, size=128 * 7 * 7)
        y = fluid.layers.batch_norm(y)
        # Reshape the tensor
        y = fluid.layers.reshape(y, shape=(-1, 128, 7, 7))
        # First transposed convolution
        y = deconv(x=y, num_filters=128, act='relu', output_size=[14, 14])
        # Second transposed convolution
        y = deconv(x=y, num_filters=1, act='tanh', output_size=[28, 28])
    return y

The discriminator learns to tell real images from generated ones: a binary classification task in which it should output a probability close to 1 for real images and close to 0 for fake ones. It uses two convolutional-pooling blocks, a fully connected layer with BN, and a final fully connected layer that outputs a single raw score (logit); the sigmoid is applied inside the loss function during training.

# Discriminator
def Discriminator(images, name="D"):
    # Define a convolutional-pooling block
    def conv_pool(input, num_filters, act=None):
        return fluid.nets.simple_img_conv_pool(input=input,
                                               filter_size=5,
                                               num_filters=num_filters,
                                               pool_size=2,
                                               pool_stride=2,
                                               act=act)

    with fluid.unique_name.guard(name + "/"):
        y = fluid.layers.reshape(x=images, shape=[-1, 1, 28, 28])
        # First convolutional-pooling block
        y = conv_pool(input=y, num_filters=64, act='leaky_relu')
        # Second convolutional-pooling with BN
        y = conv_pool(input=y, num_filters=128)
        y = fluid.layers.batch_norm(input=y, act='leaky_relu')
        # Fully connected layer with BN
        y = fluid.layers.fc(input=y, size=1024)
        y = fluid.layers.batch_norm(input=y, act='leaky_relu')
        # Final output layer: a raw logit; the sigmoid is applied inside
        # sigmoid_cross_entropy_with_logits during training
        y = fluid.layers.fc(input=y, size=1)
    return y
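
As a quick sanity check (a sketch, not part of the training code below), both networks can be built inside a scratch Program and their output shapes inspected; -1 stands for the batch dimension:

check_program = fluid.Program()
with fluid.program_guard(check_program, fluid.Program()):
    z_check = fluid.layers.data(name='z_check', shape=[100, 1, 1])
    fake_check = Generator(z_check)
    score_check = Discriminator(fake_check)
    print(fake_check.shape)   # expected: (-1, 1, 28, 28)
    print(score_check.shape)  # expected: (-1, 1)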

Define the Training Programs

We define four Programs: one for training the discriminator on real images, one for training it on fake images, one for training the generator, and a startup Program for parameter initialization. All three training Programs share the same startup Program, so their parameters are initialized together. The noise dimension is set to 100.

# Create Programs for training
train_d_fake = fluid.Program()  # Train D on fake images
train_d_real = fluid.Program()  # Train D on real images
train_g = fluid.Program()       # Train G to fool D
startup = fluid.Program()       # Program for parameter initialization
z_dim = 100                     # Noise dimension

Next, a helper that selects a Program's parameters by name prefix:

# Get parameters from a Program by prefix
def get_params(program, prefix):
    all_params = program.global_block().all_parameters()
    return [t.name for t in all_params if t.name.startswith(prefix)]
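
Because both networks are built under fluid.unique_name.guard, every parameter name carries its network's prefix, which is what makes this filter work. For illustration, the returned names look roughly like the following (hypothetical example):

# get_params(train_d_real, "D") returns names such as:
#   ['D/conv2d_0.w_0', 'D/conv2d_0.b_0', 'D/fc_0.w_0', 'D/fc_0.b_0', ...]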

Define the discriminator training program for real images:

# Train D to recognize real images
with fluid.program_guard(train_d_real, startup):
    real_image = fluid.layers.data('image', shape=[1, 28, 28])
    ones = fluid.layers.fill_constant_batch_size_like(real_image, shape=[-1, 1], dtype='float32', value=1)

    p_real = Discriminator(real_image)
    real_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p_real, ones)
    real_avg_cost = fluid.layers.mean(real_cost)

    d_params = get_params(train_d_real, "D")

    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=2e-4)
    optimizer.minimize(real_avg_cost, parameter_list=d_params)
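
The loss function sigmoid_cross_entropy_with_logits applies the sigmoid internally, which is why the discriminator's final layer outputs a raw logit. A minimal numpy sketch of the standard numerically stable formula such losses use:

def sigmoid_ce(x, label):
    # loss = max(x, 0) - x * label + log(1 + exp(-|x|))
    return np.maximum(x, 0) - x * label + np.log1p(np.exp(-np.abs(x)))

print(sigmoid_ce(np.array([3.0]), np.array([1.0])))  # ~0.049: confident and correct
print(sigmoid_ce(np.array([3.0]), np.array([0.0])))  # ~3.049: confident and wrong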

Define the discriminator training program for fake images. This Program builds both the generator and the discriminator, but parameter_list restricts the optimizer to the discriminator's parameters, so the generator is left unchanged here:

# Train D to recognize fake images
with fluid.program_guard(train_d_fake, startup):
    z = fluid.layers.data(name='z', shape=[z_dim, 1, 1])
    zeros = fluid.layers.fill_constant_batch_size_like(z, shape=[-1, 1], dtype='float32', value=0)

    p_fake = Discriminator(Generator(z))
    fake_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p_fake, zeros)
    fake_avg_cost = fluid.layers.mean(fake_cost)

    d_params = get_params(train_d_fake, "D")

    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=2e-4)
    optimizer.minimize(fake_avg_cost, parameter_list=d_params)

Define the generator training program. The labels are ones because the generator's goal is for the discriminator to classify its output as real; here parameter_list selects only the generator's parameters, so the discriminator stays fixed:

# Train G to fool D
with fluid.program_guard(train_g, startup):
    z = fluid.layers.data(name='z', shape=[z_dim, 1, 1])
    ones = fluid.layers.fill_constant_batch_size_like(z, shape=[-1, 1], dtype='float32', value=1)

    fake = Generator(z)
    infer_program = train_g.clone(for_test=True)  # Clone for inference before D and the loss are added

    p = Discriminator(fake)
    g_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p, ones)
    g_avg_cost = fluid.layers.mean(g_cost)

    g_params = get_params(train_g, "G")

    optimizer = fluid.optimizer.AdamOptimizer(learning_rate=2e-4)
    optimizer.minimize(g_avg_cost, parameter_list=g_params)
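
Taken together, the three training Programs implement the usual GAN objectives (written with the sigmoid_ce sketch above as shorthand):

# D on real:  minimize sigmoid_ce(D(x), 1)      -> push D(x) toward "real"
# D on fake:  minimize sigmoid_ce(D(G(z)), 0)   -> push D(G(z)) toward "fake"
# G:          minimize sigmoid_ce(D(G(z)), 1)   -> the non-saturating generator loss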

Training and Prediction

Define functions to generate noise and read the MNIST dataset:

# Noise generator for G
def z_reader():
    while True:
        yield np.random.normal(0.0, 1.0, (z_dim, 1, 1)).astype('float32')

# MNIST reader (ignores labels)
def mnist_reader(reader):
    def r():
        # paddle.dataset.mnist yields images already scaled to [-1, 1],
        # matching the generator's tanh output range
        for img, label in reader():
            yield img.reshape(1, 28, 28)
    return r
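
A quick way to confirm the noise reader's batch shape (a sketch; the MNIST reader analogously yields batches of images shaped (batch, 1, 28, 28)):

sample = np.array(next(paddle.batch(z_reader, batch_size=4)()))
print(sample.shape)  # (4, 100, 1, 1)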

Function to save generated images:

# Save the first 64 generated images of each pass
def show_image_grid(images, pass_id=0):
    # Make sure the output directory exists
    if not os.path.exists("image"):
        os.makedirs("image")
    for i, image in enumerate(images[:64]):
        plt.imsave("image/test_%d_%d.png" % (pass_id, i), image[0], cmap='Greys_r')

Set up data readers and the executor:

# Data readers
mnist_generator = paddle.batch(
    paddle.reader.shuffle(mnist_reader(paddle.dataset.mnist.train()), 30000), batch_size=128)
z_generator = paddle.batch(z_reader, batch_size=128)()

# Initialize executor (GPU recommended)
# place = fluid.CPUPlace()
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup)

# Fixed test noise so generated samples are comparable across passes
test_z = np.array(next(z_generator))

Start training with alternating updates to D and G:

# Training loop
for pass_id in range(5):
    for i, real_image in enumerate(mnist_generator()):
        # Train D on fake images
        r_fake = exe.run(program=train_d_fake,
                         fetch_list=[fake_avg_cost],
                         feed={'z': np.array(next(z_generator))})

        # Train D on real images
        r_real = exe.run(program=train_d_real,
                         fetch_list=[real_avg_cost],
                         feed={'image': np.array(real_image)})

        # Train G to fool D
        r_g = exe.run(program=train_g,
                      fetch_list=[g_avg_cost],
                      feed={'z': np.array(next(z_generator))})

    print("Pass: %d, Fake Loss: %f, Real Loss: %f, G Loss: %f" % 
          (pass_id, r_fake[0][0], r_real[0][0], r_g[0][0]))

    # Generate and display test images
    test_images = exe.run(program=infer_program,
                          fetch_list=[fake],
                          feed={'z': test_z})
    show_image_grid(test_images[0], pass_id)
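
After training, the parameters can optionally be persisted so the generator can be reused without retraining. A sketch using fluid.io.save_persistables ('models' is an arbitrary directory name):

fluid.io.save_persistables(exe, dirname='models', main_program=infer_program)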

Note: The latest code is on GitHub: https://github.com/yeyupiaoling/LearnPaddle2/tree/master/note6

Xiaoye