项目简介
本项目基于PyTorch框架,主要针对MNIST手写数字识别数据集开展图像分类工作。项目完整涵盖数据预处理、模型构建、训练、验证以及保存等流程,同时借助TensorBoard实现训练过程的可视化。此外,还具备自定义数据集加载和展示功能,方便用户按需调整输入数据。
项目的主要特性和功能
- 数据预处理:下载MNIST数据集,调整图像大小并保存为PNG格式,生成CSV文件记录图像路径和标签信息,便于加载自定义数据集。
- 模型构建:定义包含两个卷积层和一个全连接层的卷积神经网络(CNN)模型,支持自定义输入大小以适应不同数据集。
- 模型训练:使用定义好的模型进行训练,包含前向传播、损失计算、反向传播和优化步骤,通过TensorBoard可视化训练过程中的损失和准确率。
- 模型验证:用测试集对训练好的模型进行测试,计算并打印准确率。
- 模型保存:保存训练好的模型权重,方便后续使用或继续训练。
- 数据展示:加载自定义数据集,展示一批图像,检查数据预处理效果。
安装使用步骤
安装Anaconda
根据Anaconda安装教程安装最新版本的Anaconda。
设置虚拟环境
打开命令提示符(CMD),使用以下命令创建并激活虚拟环境:
bash
conda create --name myenv python=3.7
conda activate myenv
下载项目代码
使用以下命令复制项目代码到本地:
bash
安装所需依赖
在虚拟环境中安装所需的依赖库:
bash
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch
conda install tensorflow==1.14.0=mkl*
conda install opencv
conda install spyder
conda install -c conda-forge tqdm
运行项目
使用以下命令启动Spyder IDE:
bash
spyder
在Spyder中打开项目目录下的NN.py
或CNN.py
文件,按F5
运行代码。
使用TensorBoard
在训练过程中,使用TensorBoard查看训练日志,监控损失和准确率的变化。
注意事项
- 确保Python环境已正确安装所需库。
- 修改配置文件路径以适应本地环境的数据集路径。
- 可根据项目需求修改模型结构和超参数。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】