模型训练与预测全流程解析¶
1. 数据集与模型选择¶
- CIFAR-10数据集:包含10类32×32彩色图像,共60000张,训练集50000张,测试集10000张。
- 模型对比:
- VGG16:使用BN层和Dropout解决梯度弥散,16层卷积网络结构。
- ResNet:残差网络通过shortcut连接解决深层网络退化问题,32层基础结构。
2. 训练代码核心逻辑¶
class TestCIFAR:
def __init__(self):
paddle.init(use_gpu=False, trainer_count=2) # 初始化PaddlePaddle,关闭GPU
def get_parameters(self, parameters_path=None, cost=None):
if parameters_path:
# 加载预训练参数
with open(parameters_path, 'r') as f:
return paddle.parameters.Parameters.from_tar(f)
else:
# 创建新参数
return paddle.parameters.create(cost)
def get_trainer(self):
datadim = 3 * 32 * 32 # 图像维度 (3通道×32×32)
lbl = paddle.layer.data(name="label", type=paddle.data_type.integer_value(10))
# 选择模型(VGG或ResNet)
out = vgg_bn_drop(datadim) # VGG模型
# out = resnet_cifar10(datadim) # ResNet模型
cost = paddle.layer.classification_cost(input=out, label=lbl)
parameters = self.get_parameters(cost=cost)
# 优化器配置
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
regularization=paddle.optimizer.L2Regularization(rate=0.0002 * 128),
learning_rate=0.1 / 128.0,
learning_rate_decay_a=0.1,
learning_rate_decay_b=50000 * 100,
learning_rate_schedule="discexp"
)
return paddle.trainer.SGD(cost=cost, parameters=parameters, update_equation=optimizer)
def start_trainer(self):
# 加载训练数据(自动缓存)
reader = paddle.batch(
reader=paddle.reader.shuffle(paddle.dataset.cifar.train10(), buf_size=50000),
batch_size=128
)
feeding = {"image": 0, "label": 1} # 数据映射
# 训练事件处理
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
print(f"Pass {event.pass_id}, Batch {event.batch_id}, Cost {event.cost}")
if isinstance(event, paddle.event.EndPass):
# 保存模型
model_path = "../model"
os.makedirs(model_path, exist_ok=True)
with open(f"{model_path}/model.tar", "w") as f:
trainer.save_parameter_to_tar(f)
# 测试集验证
result = trainer.test(
reader=paddle.batch(paddle.dataset.cifar.test10(), batch_size=128),
feeding=feeding
)
print(f"Test Pass {event.pass_id}, Result {result.metrics}")
trainer = self.get_trainer()
trainer.train(
reader=reader,
num_passes=100,
event_handler=event_handler,
feeding=feeding
)
3. 训练关键参数¶
- 优化器:Momentum+L2正则化(防止过拟合)
- 学习率:初始0.1/128,随训练衰减
- 批次大小:128(GPU内存有限时可减小)
- 训练轮数:100轮(可根据测试集准确率调整)
4. 预测代码实现¶
def to_prediction(self, image_path, parameters, out):
def load_image(file):
im = Image.open(file).resize((32, 32), Image.ANTIALIAS)
im = np.array(im).astype(np.float32)
im = im.transpose((2, 0, 1)) # HWC→CHW
im = im[(2, 1, 0), :, :] # RGB→BGR(匹配CIFAR格式)
return im.flatten() / 255.0 # 归一化
test_data = [(load_image(image_path),)]
probs = paddle.infer(output_layer=out, parameters=parameters, input=test_data)
lab = np.argsort(-probs)[0][0] # 取概率最大的类别
return lab, probs[0][lab]
5. 可视化训练过程¶
from paddle.v2.plot import Ploter
cost_ploter = Ploter("Train Cost", "Test Cost")
def event_handler_plot(event):
global step
if isinstance(event, paddle.event.EndIteration):
cost_ploter.append("Train Cost", step, event.cost)
cost_ploter.plot()
if isinstance(event, paddle.event.EndPass):
cost_ploter.append("Test Cost", step, result.cost)
trainer.train(reader=reader, event_handler=event_handler_plot)
6. 项目结构与扩展¶
- 模型文件:
vgg.py、resnet.py分别定义VGG16和ResNet模型 - 训练脚本:
train.py实现完整训练流程 - 预测脚本:
infer.py加载模型进行单张图像预测 - 扩展建议:
- 增加早停机制(验证集准确率不再提升时停止)
- 尝试更大批次或学习率调整
- 对比不同模型(VGG vs ResNet)的收敛速度与精度
7. 结果验证¶
- 训练输出:每100批次打印训练成本,每轮结束保存模型并测试
- 测试集结果:最终错误率稳定在15%左右(VGG模型)
- 预测示例:输入图像路径返回类别(0-9)及置信度
# 调用预测
testCIFAR = TestCIFAR()
parameters = testCIFAR.get_parameters("../model/model.tar")
result, prob = testCIFAR.to_prediction("airplane1.png", parameters, out)
print(f"预测类别: {result}, 置信度: {prob:.4f}")
总结¶
通过PaddlePaddle实现CIFAR-10图像分类任务,重点在于:
1. 合理选择模型结构(VGG/ResNet)
2. 配置优化器与正则化参数
3. 利用事件处理机制监控训练过程
4. 预处理图像格式适配CIFAR-10标准
该框架可灵活扩展至其他图像分类任务,只需替换模型或调整输入维度即可。