项目简介
本项目基于PyTorch框架开展图像分类模型的训练。功能涵盖模型的定义、训练、验证、测试以及模型参数的保存与加载等操作。训练使用CIFAR - 10数据集,该数据集有10个类别的彩色图像,为模型提供丰富素材。项目构建简单的卷积神经网络模型,结合交叉熵损失函数和随机梯度下降优化器进行训练,实现精准图像分类。
项目的主要特性和功能
- 数据加载与预处理:利用PyTorch的DataLoader模块从本地目录高效加载CIFAR10数据集,并预处理数据以符合模型输入要求。
- 模型定义:通过PyTorch的nn模块定义神经网络模型,包含多个卷积层、池化层和全连接层,具备特征提取和分类能力。
- 损失函数与优化器:采用交叉熵损失函数衡量预测结果与真实标签差异,使用随机梯度下降优化器(学习率0.01)更新模型参数。
- 训练循环:支持多轮训练,每轮对数据批次处理,计算输出、损失值,通过反向传播更新模型参数。
- 测试集评估:每轮训练后自动评估测试集,计算整体误差和准确率,反映模型在未知数据上的性能。
- 模型保存与加载:训练中按轮次保存模型参数,方便后续训练或预测;可快速加载参数恢复模型状态。
- 可视化展示:使用Tensorboard工具可视化训练过程中的损失和准确率等信息。
- 图像格式转换:借助torchvision的transforms模块将不同格式图像转换为Tensor格式。
- GPU加速:支持使用GPU训练,通过torch.device指定训练设备提升速度。
安装使用步骤
环境准备
- 从官网(www.anaconda.com)下载并安装anaconda。
- 修改Anaconda文件夹权限,避免在C盘创建虚拟环境。
- 打开命令行工具,创建pytorch虚拟环境,Python版本不高于3.9:
powershell conda create -n pytorch python=3.8
- 激活pytorch虚拟环境:
powershell conda activate pytorch
- 前往https://pytorch.org/选择相关配置,获取pytorch包的下载命令,例如:
powershell conda install pytorch torchvision torchaudio pytorch - cuda=12.1 -c pytorch -c nvidia
- 测试pytorch是否安装成功:
bash python import torch torch.cuda.is_available()
- 若使用Jupyter工具,在pytorch虚拟环境中安装Jupyter包和nb_conda:
powershell conda install jupyter notebook conda install nb_conda
测试Jupyter是否安装成功:powershell jupyter notebook
项目运行
- 准备数据:项目自动从指定路径加载CIFAR10数据集,确保数据集路径与代码一致,若不存在会自动下载。
- 运行训练代码:在激活的pytorch虚拟环境中,进入项目代码目录,运行
train_model.py
脚本开始训练:bash python train_model.py
- 查看训练结果:训练中TensorBoard记录损失和准确率信息。在命令行输入以下命令启动TensorBoard:
bash tensorboard --logdir=../model_logs --port=6007
然后在浏览器中打开http://localhost:6007
查看可视化结果。 - 测试模型:训练结束后,运行
test.py
脚本测试模型。准备测试图像,转换为合适格式,加载训练好的模型进行预测: ```python import torch from torchvision import transforms from PIL import Image from model import Model
image_path = "../imgs/ship.png" image = Image.open(image_path).convert("RGB") transform = transforms.Compose([transforms.Resize([32, 32]), transforms.ToTensor()]) image = transform(image) image = torch.reshape(image, [1, 3, 32, 32])
model = Model() model.load_state_dict(torch.load("model_30.pth"))
model.eval() with torch.no_grad(): out = model(image) print(out.argmax(1).item()) ```
Conda常见命令
- 查看虚拟环境:
powershell conda info --env
- 删除虚拟环境:
powershell conda remove -n env_name --all
- 更新conda命令:
powershell conda update -n base -c defaults conda
- 查看源:
powershell conda config --show - sources
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】