项目简介
本项目基于PyTorch框架实现了LeNet - 5神经网络模型,用于在MNIST数据集上开展手写数字识别工作。项目包含模型定义、数据预处理、模型训练、模型预测等模块,通过搭建与训练卷积神经网络模型,达成图像的自动分类,为初学者提供了优质的学习范例。
项目的主要特性和功能
- 经典网络架构:运用LeNet - 5卷积神经网络架构,此网络结构简洁清晰,是CNN研究的基础。
- 数据预处理:对MNIST数据集实施预处理,把图像尺寸调整为32x32,并进行标准化处理,防止梯度丢失问题。
- 模型训练:定义单次训练和验证流程,采用Adam优化器与交叉熵损失函数进行模型训练,记录训练和验证损失,保存最佳模型。
- 测试功能:支持在MNIST验证集上随机抽取图片测试,也能对自己的手写数字图片进行识别。
安装使用步骤
假设用户已下载本项目的源码文件,可按如下步骤操作:
1. 环境准备:确保已安装Python和PyTorch。
2. 数据下载:运行代码中下载MNIST数据集的部分,数据集会自动下载到指定的data
目录。
python
train_dataset = datasets.MNIST(root='data', train=True, transform=transforms, download=True)
valid_dataset = datasets.MNIST(root='data', train=False, transform=transforms)
3. 模型训练:运行训练代码,设置训练轮数等参数,训练完成后最佳模型将保存到Model/LeNet5.pth
。
python
model, optimizer, _ = training_process(model, criterion, optimizer, train_loader, valid_loader, N_EPOCHS, DEVICE)
torch.save(model, "Model/LeNet5.pth")
4. 模型测试:
- MNIST验证集测试:使用predict.py
中的multi_predict()
函数随机从MNIST的验证集抽取50张进行测试。
- 手写数字测试:使用predict.py
中编写的single_predict()
函数,传入图片路径,输出检测结果。注意,若手写数字是白底黑字,需对灰度图进行反相处理。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】