Foreword
In Chapter 6, we introduced Generative Adversarial Networks (GANs) and trained one on the MNIST dataset to generate handwritten digit images. In this chapter, we will train a GAN on our own image dataset and use it to generate new images. Unlike the single-channel grayscale images used in Chapter 6, this chapter uses 3-channel color images.
GitHub repository: https://github.com/yeyupiaoling/LearnPaddle2/tree/master/note13
Defining the Data Reader
First, create an image_reader.py file that reads our custom image dataset, and import the required dependencies:
import os
import random
from multiprocessing import cpu_count
import numpy as np
import paddle
from PIL import Image
Preprocessing consists of aspect-preserving scaling followed by a center crop (so images are not distorted when resized), a random horizontal flip (to increase the diversity of the dataset), and conversion of single-channel images to 3 channels (so that grayscale images do not interrupt training):
# Image preprocessing for training
def train_mapper(sample):
    """Load one image file and turn it into a training array.

    Steps: random horizontal flip, aspect-preserving resize so the shorter side
    equals ``crop_size``, centre crop to ``crop_size`` x ``crop_size``, then
    conversion to a normalised CHW array.

    Args:
        sample: ``(image_path, crop_size)`` tuple as yielded by ``train_reader``.

    Returns:
        float32 ``np.ndarray`` of shape ``(3, crop_size, crop_size)`` with the
        channels in BGR order and values in ``[0, 1]``.
    """
    img_path, crop_size = sample
    # Image.ANTIALIAS was removed in Pillow 10. LANCZOS is the same filter
    # (ANTIALIAS was an alias for it) and also exists in older Pillow releases.
    resample = Image.LANCZOS
    # convert('RGB') normalises every mode (L, LA, P, RGBA, CMYK, ...).
    # Merging the bands of a palette image raises, and 2- or 4-band images
    # would break or corrupt the BGR reorder below. The context manager
    # releases the file handle; convert() returns a new, fully loaded image.
    with Image.open(img_path) as src:
        img = src.convert('RGB')
    # Random horizontal flip (data augmentation)
    if random.random() > 0.5:
        img = img.transpose(Image.FLIP_LEFT_RIGHT)
    # Scale so the shorter side becomes crop_size (aspect ratio preserved),
    # then centre-crop the longer side, so the image is never distorted.
    width, height = img.size
    portrait = width < height
    ratio = min(width, height) / crop_size
    width, height = width / ratio, height / ratio
    img = img.resize((int(width), int(height)), resample)
    half = crop_size / 2
    if portrait:
        box = (0, int(height / 2 - half), int(width), int(height / 2 + half))
    else:
        box = (int(width / 2 - half), 0, int(width / 2 + half), int(height))
    # The final resize absorbs any off-by-one from the int() truncation above.
    img = img.crop(box).resize((crop_size, crop_size), resample)
    # HWC uint8 -> CHW float32. Then reverse the channel axis (RGB -> BGR)
    # and scale to [0, 1].
    arr = np.asarray(img, dtype=np.float32).transpose((2, 0, 1))
    return arr[::-1, :, :] / 255.0
Training does not need a data list, because every image is used to learn how to generate new ones. We therefore simply read all images in the dataset directory:
# Image reader for training
def train_reader(train_image_path, crop_size,
                 extensions=('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp', '.tif', '.tiff')):
    """Build a parallel reader over every image found under ``train_image_path``.

    Args:
        train_image_path: root directory, scanned recursively with ``os.walk``.
        crop_size: side length passed to ``train_mapper`` for each image.
        extensions: case-insensitive file suffixes treated as images. Other
            files (``.DS_Store``, ``Thumbs.db``, notes, ...) are skipped, because
            PIL cannot open them and a single one would abort training. Pass
            ``None`` to keep every file.

    Returns:
        A paddle reader that yields preprocessed image arrays, mapped by
        ``train_mapper`` over ``cpu_count()`` workers with a buffer of 1024.

    Raises:
        ValueError: if no matching file exists, instead of silently training
            on an empty dataset.
    """
    suffixes = None if extensions is None else tuple(ext.lower() for ext in extensions)
    image_paths = []
    for root, _dirs, files in os.walk(train_image_path):
        for name in files:
            if suffixes is None or name.lower().endswith(suffixes):
                image_paths.append(os.path.join(root, name))
    if not image_paths:
        raise ValueError('no image files found under %r' % (train_image_path,))

    def reader():
        for path in image_paths:
            yield path, crop_size

    return paddle.reader.xmap_readers(train_mapper, reader, cpu_count(), 1024)
Training the Generative Model
Create train.py to train the GAN and generate/save images during training. Import required dependencies:
import os
import shutil
import numpy as np
import paddle
import paddle.fluid as fluid
import matplotlib.pyplot as plt
import image_reader
The generator tries to produce images that fool the discriminator. It consists of fully connected layers, batch normalization layers, and transposed convolutions. Its output has 3 channels (one per color channel) and uses a sigmoid activation, so values lie between 0 and 1 like the normalized training images:
# Image size for training
# Side length in pixels of the square images used by both networks. Generator
# reshapes its fc output to image_size / 4 and upsamples twice to image_size,
# so this should be a multiple of 4.
image_size = 112
# Define Generator
def Generator(y, name="G"):
    """Map a batch of noise vectors to a batch of 3-channel images.

    Args:
        y: noise batch (the training programs feed a ``z`` input of shape ``[z_dim]``).
        name: prefix for every parameter created here; the generator optimizer
            selects its parameters by this prefix.

    Returns:
        Tensor of shape ``[-1, 3, image_size, image_size]`` after a sigmoid,
        so values lie in (0, 1) like the normalised training images.
    """
    def deconv(x, num_filters, filter_size=5, stride=2, dilation=1, padding=2, output_size=None, act=None):
        # Thin wrapper around conv2d_transpose with stride-2 defaults;
        # output_size pins the exact spatial size of the result.
        return fluid.layers.conv2d_transpose(input=x,
                                             num_filters=num_filters,
                                             output_size=output_size,
                                             filter_size=filter_size,
                                             stride=stride,
                                             dilation=dilation,
                                             padding=padding,
                                             act=act)
    # unique_name.guard restarts name generation under the "G/" prefix, so
    # every program that calls Generator() gets identical parameter names.
    # Do not reorder the layers below: creation order fixes those names.
    with fluid.unique_name.guard(name + "/"):
        # First fully connected + BN
        # NOTE(review): neither fc + batch_norm pair has an activation, so the two
        # fc layers compose to (nearly) one affine map; DCGAN-style generators
        # usually use act='relu' on these batch_norm calls. Changing it alters
        # the architecture and previously saved models, so confirm first.
        y = fluid.layers.fc(y, size=2048)
        y = fluid.layers.batch_norm(y)
        # Second fully connected + BN, sized for a 128 x (image_size/4)^2 feature map
        y = fluid.layers.fc(y, size=int(128 * (image_size / 4) * (image_size / 4)))
        y = fluid.layers.batch_norm(y)
        # Reshape to NCHW: 128 channels at a quarter of the target resolution
        y = fluid.layers.reshape(y, shape=[-1, 128, int((image_size / 4)), int((image_size / 4))])
        # Transposed convolution 1: upsample to half resolution
        y = deconv(x=y, num_filters=128, act='relu', output_size=[int((image_size / 2)), int((image_size / 2))])
        # Transposed convolution 2 (output: 3 channels, sigmoid for [0,1])
        y = deconv(x=y, num_filters=3, act='sigmoid', output_size=[image_size, image_size])
    return y
The discriminator classifies images as real or fake. It uses convolution/pooling blocks followed by fully connected layers that produce a single score per image:
# Define Discriminator
def Discriminator(images, name="D"):
    """Score each image as real or fake.

    Args:
        images: image batch, reshaped here to ``[-1, 3, image_size, image_size]``.
        name: prefix for every parameter created here; the discriminator
            optimizers select their parameters by this prefix.

    Returns:
        A ``[batch, 1]`` tensor of raw logits (higher means "more real").
        No sigmoid is applied. Every loss built on this output is
        ``sigmoid_cross_entropy_with_logits``, which applies the sigmoid
        itself. Applying it twice squashes all scores into (0.5, 0.73), so D
        can never confidently reject a fake and the gradients collapse.
    """
    def conv_pool(input, num_filters, act=None):
        # 3x3 convolution followed by 2x2, stride-2 pooling.
        return fluid.nets.simple_img_conv_pool(input=input,
                                               filter_size=3,
                                               num_filters=num_filters,
                                               pool_size=2,
                                               pool_stride=2,
                                               act=act)
    # unique_name.guard restarts name generation under the "D/" prefix, so the
    # Discriminator built in each training program gets the same parameter
    # names. Keep the layer order stable.
    with fluid.unique_name.guard(name + "/"):
        y = fluid.layers.reshape(x=images, shape=[-1, 3, image_size, image_size])
        # First conv-pool block
        y = conv_pool(input=y, num_filters=64, act='leaky_relu')
        # Second conv-pool block; its activation is applied after batch norm
        y = conv_pool(input=y, num_filters=128)
        y = fluid.layers.batch_norm(input=y, act='leaky_relu')
        # Fully connected layer
        y = fluid.layers.fc(input=y, size=1024)
        y = fluid.layers.batch_norm(input=y, act='leaky_relu')
        # Output: one raw logit per image; the loss applies the sigmoid.
        y = fluid.layers.fc(input=y, size=1)
    return y
Define programs for training discriminator (real/fake) and generator:
# One program per optimisation step: D on fakes, D on real images, and G.
# Keeping them separate lets each optimizer update only its own sub-network.
# All three share a single startup program for parameter initialisation.
train_d_fake, train_d_real, train_g = fluid.Program(), fluid.Program(), fluid.Program()
startup = fluid.Program()
# Length of the latent noise vector fed to the generator
z_dim = 100
# Get parameters by prefix from a program
def get_params(program, prefix):
    """Return the names of ``program``'s global-block parameters starting with ``prefix``.

    Used to restrict an optimizer to one sub-network, e.g. prefix "D" or "G".
    """
    selected = []
    for param in program.global_block().all_parameters():
        if param.name.startswith(prefix):
            selected.append(param.name)
    return selected
Train discriminator on real images (label=1):
# Train D on real images
# Builds train_d_real: one Adam step on the discriminator's parameters,
# pushing its score for real images towards the label 1.
with fluid.program_guard(train_d_real, startup):
    # Real image batch; each sample has shape [3, image_size, image_size].
    real_image = fluid.layers.data('image', shape=[3, image_size, image_size])
    # Target of 1 for every sample, sized to the actual batch.
    ones = fluid.layers.fill_constant_batch_size_like(real_image, shape=[-1, 1], dtype='float32', value=1)
    p_real = Discriminator(real_image)
    real_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p_real, ones)
    real_avg_cost = fluid.layers.mean(real_cost)
    # Only parameters whose names start with "D" are optimised here.
    d_params = get_params(train_d_real, "D")
    optimizer = fluid.optimizer.Adam(learning_rate=2e-4)
    optimizer.minimize(real_avg_cost, parameter_list=d_params)
Train discriminator on fake images (label=0):
# Train D on fake images
# Builds train_d_fake: generate images from noise and take one Adam step on the
# discriminator only, pushing its score for generated images towards 0.
with fluid.program_guard(train_d_fake, startup):
    # Noise batch; each sample is a vector of length z_dim.
    z = fluid.layers.data(name='z', shape=[z_dim])
    # Target of 0 ("fake") for every sample, sized to the actual batch.
    zeros = fluid.layers.fill_constant_batch_size_like(z, shape=[-1, 1], dtype='float32', value=0)
    p_fake = Discriminator(Generator(z))
    fake_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p_fake, zeros)
    fake_avg_cost = fluid.layers.mean(fake_cost)
    # The generator is part of this graph, but only "D" parameters are passed
    # to the optimizer, so G is left unchanged by this step.
    d_params = get_params(train_d_fake, "D")
    optimizer = fluid.optimizer.Adam(learning_rate=2e-4)
    optimizer.minimize(fake_avg_cost, parameter_list=d_params)
Train generator to fool discriminator (label=1):
# Train G to generate fake images
# Module-level handle to the generator output; it is fetched through
# infer_program and exported with the inference model later in this script.
fake = None
with fluid.program_guard(train_g, startup):
    z = fluid.layers.data(name='z', shape=[z_dim])
    # G is rewarded when D scores its samples as real, so the target is 1.
    ones = fluid.layers.fill_constant_batch_size_like(z, shape=[-1, 1], dtype='float32', value=1)
    fake = Generator(z)
    # Cloned before Discriminator is added, so infer_program contains only the
    # generator sub-graph; for_test=True puts ops such as batch_norm in
    # inference mode.
    infer_program = train_g.clone(for_test=True)
    p = Discriminator(fake)
    g_cost = fluid.layers.sigmoid_cross_entropy_with_logits(p, ones)
    g_avg_cost = fluid.layers.mean(g_cost)
    # Only "G" parameters are optimised; D stays fixed during this step.
    g_params = get_params(train_g, "G")
    optimizer = fluid.optimizer.Adam(learning_rate=2e-4)
    optimizer.minimize(g_avg_cost, parameter_list=g_params)
Noise generator and image saving functions:
# Noise generator for GAN training
def z_reader():
    """Yield float32 noise vectors of length z_dim, uniform in [-1, 1), forever."""
    low, high = -1.0, 1.0
    while True:
        sample = np.random.uniform(low, high, z_dim)
        yield sample.astype('float32')
# Save generated images
def show_image_grid(images):
for i, image in enumerate(images):
image = image.transpose((2, 1, 0))
save_path = 'train_image'
if not os.path.exists(save_path):
os.makedirs(save_path)
plt.imsave(os.path.join(save_path, "test_%d.png" % i), image)
Create the data readers and the executor, then start training:
# Data loaders
# Batches of real images from ./datasets and an endless stream of noise batches.
mydata_generator = paddle.batch(reader=image_reader.train_reader('datasets', image_size), batch_size=32)
z_generator = paddle.batch(z_reader, batch_size=32)()
# Fixed noise batch used only for the per-epoch preview images, so the samples
# saved after successive epochs come from the same latent vectors.
test_z = np.array(next(z_generator))
# Execute with GPU (recommended for speed)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup)
# Training loop
for pass_id in range(100):
    for i, real_image in enumerate(mydata_generator()):
        # Draw fresh noise every step. Feeding the fixed test_z here would
        # train both networks on the same 32 latent vectors for the whole run.
        z_batch = np.array(next(z_generator))
        # Train D on fake images
        r_fake = exe.run(program=train_d_fake, fetch_list=[fake_avg_cost], feed={'z': z_batch})
        # Train D on real images
        r_real = exe.run(program=train_d_real, fetch_list=[real_avg_cost], feed={'image': np.array(real_image)})
        # Train G
        r_g = exe.run(program=train_g, fetch_list=[g_avg_cost], feed={'z': z_batch})
        if i % 100 == 0:
            # Each fetched cost is a small array (presumably shape [1]). Reduce
            # it to a Python float explicitly: NumPy >= 1.25 deprecates the
            # implicit conversion of ndim > 0 arrays that %-formatting relied on.
            print("Pass: %d, Batch: %d, D(real): %.5f, D(fake): %.5f, G: %.5f" %
                  (pass_id, i, float(np.mean(r_real[0])), float(np.mean(r_fake[0])), float(np.mean(r_g[0]))))
    # Generate and save images after each epoch
    r_i = exe.run(program=infer_program, fetch_list=[fake], feed={'z': test_z})
    show_image_grid(np.array(r_i).astype(np.float32)[0])
    # Save inference model (re-exported every epoch, so an interrupted run
    # still leaves the latest generator on disk)
    save_path = 'infer_model/'
    shutil.rmtree(save_path, ignore_errors=True)
    os.makedirs(save_path)
    fluid.io.save_inference_model(save_path, feeded_var_names=[z.name], target_vars=[fake], executor=exe, main_program=train_g)
Generating Images with the Model
Create infer.py to load the trained model and generate images:
import os
import paddle
import matplotlib.pyplot as plt
import numpy as np
import paddle.fluid as fluid
Load the inference model:
# Sampling from the trained generator is cheap, so run inference on the CPU.
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# Restore the generator graph and its parameters exported by train.py.
save_path = 'infer_model/'
infer_program, feeded_var_names, target_var = fluid.io.load_inference_model(
    dirname=save_path, executor=exe)
Generate noise and predict:
# Noise generator: latent length and distribution match those used in training.
z_dim = 100


def z_reader():
    """Yield float32 vectors of length z_dim drawn uniformly from [-1, 1), forever."""
    while True:
        noise = np.random.uniform(low=-1.0, high=1.0, size=z_dim)
        yield noise.astype('float32')


# Take one batch of 32 noise vectors and stack it into a (32, z_dim) array.
z_generator = paddle.batch(z_reader, batch_size=32)()
test_z = np.array(next(z_generator))
# Save generated images
def save_image(images):
for i, image in enumerate(images):
image = image.transpose((2, 1, 0))
save_path = 'infer_image'
if not os.path.exists(save_path):
os.makedirs(save_path)
plt.imsave(os.path.join(save_path, "test_%d.png" % i), image)
# Predict and save
r_i = exe.run(program=infer_program, feed={feeded_var_names[0]: test_z}, fetch_list=target_var)
r_i = np.array(r_i).astype(np.float32)
save_image(r_i[0])
print('Image generation complete')
Note: The model may need tuning for better results with complex images. For improvements, adjust hyperparameters, network architecture, or use techniques like progressive growing.
Previous Chapter: Chapter 12 - Custom Text Dataset Classification
Next Chapter: Chapter 14 - Deploying Models on Servers
References
- https://github.com/oraoto/learn_ml/blob/master/paddle/gan-mnist-split.ipynb
- https://www.cnblogs.com/max-hu/p/7129188.html
- https://blog.csdn.net/somtian/article/details/72126328