模型训练与预测全流程解析

1. 数据集与模型选择

  • CIFAR-10数据集:包含10类32×32彩色图像,共60000张,训练集50000张,测试集10000张。
  • 模型对比
  • VGG16:使用BN层和Dropout解决梯度弥散,16层卷积网络结构。
  • ResNet:残差网络通过shortcut连接解决深层网络退化问题,32层基础结构。

2. 训练代码核心逻辑

class TestCIFAR:
    def __init__(self):
        paddle.init(use_gpu=False, trainer_count=2)  # 初始化PaddlePaddle,关闭GPU

    def get_parameters(self, parameters_path=None, cost=None):
        if parameters_path:
            # 加载预训练参数
            with open(parameters_path, 'r') as f:
                return paddle.parameters.Parameters.from_tar(f)
        else:
            # 创建新参数
            return paddle.parameters.create(cost)

    def get_trainer(self):
        datadim = 3 * 32 * 32  # 图像维度 (3通道×32×32)
        lbl = paddle.layer.data(name="label", type=paddle.data_type.integer_value(10))

        # 选择模型(VGG或ResNet)
        out = vgg_bn_drop(datadim)  # VGG模型
        # out = resnet_cifar10(datadim)  # ResNet模型

        cost = paddle.layer.classification_cost(input=out, label=lbl)
        parameters = self.get_parameters(cost=cost)

        # 优化器配置
        optimizer = paddle.optimizer.Momentum(
            momentum=0.9,
            regularization=paddle.optimizer.L2Regularization(rate=0.0002 * 128),
            learning_rate=0.1 / 128.0,
            learning_rate_decay_a=0.1,
            learning_rate_decay_b=50000 * 100,
            learning_rate_schedule="discexp"
        )

        return paddle.trainer.SGD(cost=cost, parameters=parameters, update_equation=optimizer)

    def start_trainer(self):
        # 加载训练数据(自动缓存)
        reader = paddle.batch(
            reader=paddle.reader.shuffle(paddle.dataset.cifar.train10(), buf_size=50000),
            batch_size=128
        )
        feeding = {"image": 0, "label": 1}  # 数据映射

        # 训练事件处理
        def event_handler(event):
            if isinstance(event, paddle.event.EndIteration):
                if event.batch_id % 100 == 0:
                    print(f"Pass {event.pass_id}, Batch {event.batch_id}, Cost {event.cost}")
            if isinstance(event, paddle.event.EndPass):
                # 保存模型
                model_path = "../model"
                os.makedirs(model_path, exist_ok=True)
                with open(f"{model_path}/model.tar", "w") as f:
                    trainer.save_parameter_to_tar(f)
                # 测试集验证
                result = trainer.test(
                    reader=paddle.batch(paddle.dataset.cifar.test10(), batch_size=128),
                    feeding=feeding
                )
                print(f"Test Pass {event.pass_id}, Result {result.metrics}")

        trainer = self.get_trainer()
        trainer.train(
            reader=reader,
            num_passes=100,
            event_handler=event_handler,
            feeding=feeding
        )

3. 训练关键参数

  • 优化器:Momentum+L2正则化(防止过拟合)
  • 学习率:初始0.1/128,随训练衰减
  • 批次大小:128(GPU内存有限时可减小)
  • 训练轮数:100轮(可根据测试集准确率调整)

4. 预测代码实现

def to_prediction(self, image_path, parameters, out):
    def load_image(file):
        im = Image.open(file).resize((32, 32), Image.ANTIALIAS)
        im = np.array(im).astype(np.float32)
        im = im.transpose((2, 0, 1))  # HWC→CHW
        im = im[(2, 1, 0), :, :]  # RGB→BGR(匹配CIFAR格式)
        return im.flatten() / 255.0  # 归一化

    test_data = [(load_image(image_path),)]
    probs = paddle.infer(output_layer=out, parameters=parameters, input=test_data)
    lab = np.argsort(-probs)[0][0]  # 取概率最大的类别
    return lab, probs[0][lab]

5. 可视化训练过程

from paddle.v2.plot import Ploter
cost_ploter = Ploter("Train Cost", "Test Cost")

def event_handler_plot(event):
    global step
    if isinstance(event, paddle.event.EndIteration):
        cost_ploter.append("Train Cost", step, event.cost)
        cost_ploter.plot()
    if isinstance(event, paddle.event.EndPass):
        cost_ploter.append("Test Cost", step, result.cost)

trainer.train(reader=reader, event_handler=event_handler_plot)

6. 项目结构与扩展

  • 模型文件vgg.pyresnet.py分别定义VGG16和ResNet模型
  • 训练脚本train.py实现完整训练流程
  • 预测脚本infer.py加载模型进行单张图像预测
  • 扩展建议
  • 增加早停机制(验证集准确率不再提升时停止)
  • 尝试更大批次或学习率调整
  • 对比不同模型(VGG vs ResNet)的收敛速度与精度

7. 结果验证

  • 训练输出:每100批次打印训练成本,每轮结束保存模型并测试
  • 测试集结果:最终错误率稳定在15%左右(VGG模型)
  • 预测示例:输入图像路径返回类别(0-9)及置信度
# 调用预测
testCIFAR = TestCIFAR()
parameters = testCIFAR.get_parameters("../model/model.tar")
result, prob = testCIFAR.to_prediction("airplane1.png", parameters, out)
print(f"预测类别: {result}, 置信度: {prob:.4f}")

总结

通过PaddlePaddle实现CIFAR-10图像分类任务,重点在于:
1. 合理选择模型结构(VGG/ResNet)
2. 配置优化器与正则化参数
3. 利用事件处理机制监控训练过程
4. 预处理图像格式适配CIFAR-10标准

该框架可灵活扩展至其他图像分类任务,只需替换模型或调整输入维度即可。

Xiaoye