项目简介
本项目借助PyTorch框架,运用卷积神经网络(CNN)模型开展MNIST手写数字识别工作。项目涵盖数据预处理、模型定义、训练、测试以及模型保存等流程,同时利用Tensorboard进行日志记录与可视化。
项目的主要特性和功能
- 数据预处理:具备数据下载、大小调整、保存和加载等功能,用于处理MNIST手写数字数据集。
- 模型定义:定义了包含两个卷积层的卷积神经网络模型,用于识别手写数字。
- 训练过程:利用定义好的CNN模型进行训练,借助Tensorboard进行日志记录和可视化。
- 测试部分:在测试数据集上评估模型性能,计算测试准确率。
- 模型保存:训练完成后,保存模型的状态字典。
安装使用步骤
环境配置
- 安装Anaconda,设置虚拟环境,保证PyTorch和torchvision等库版本正确。
- 使用以下命令创建并激活虚拟环境:
terminal conda create --name myenv python=3.7 conda activate myenv
安装依赖
安装PyTorch、TensorFlow、OpenCV等必要的库:
terminal
conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch
conda install tensorflow==1.14.0=mkl*
conda install opencv
conda install spyder
conda install -c conda-forge tqdm
数据准备
- 使用
data_save.py
脚本下载MNIST数据集,将图像大小调整为64x64像素。 - 运行以下命令:
terminal spyder
- 在Spyder中打开并运行
data_save.py
。
模型训练
- 使用
MNIST_CNN.py
脚本进行模型训练,用logger.py
进行日志记录。 - 在Spyder中打开并运行
MNIST_CNN.py
。
模型测试
- 使用
MNIST_CNN.py
脚本进行模型测试,计算测试准确率。 - 在Spyder中打开并运行
MNIST_CNN.py
。
结果展示
- 使用
show_data.py
脚本展示数据集的部分图像。 - 在Spyder中打开并运行
show_data.py
。
注意事项
- 确保安装正确版本的PyTorch和其他依赖库。
- 使用
data_save.py
下载MNIST数据集时,确保网络连接正常。 - 训练模型时,确保有足够的计算资源,如GPU。
- 根据需要调整超参数,如训练周期、批量大小和学习率等。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】