项目简介
本项目运用Swin - Transformer和Query2Label的方案(模型简称为QST)来解决Plant Pathology 2021 - FGVC8的叶片分类问题。通过将12类的混合分类简化为6类,采用独热码对标签分类,搭建Swin Transformer并在后端加入Query2Label模型,实现叶片的多标签分类。
项目的主要特性和功能
- 模型架构:以Swin Transformer作为特征提取器,结合Query2Label网络优化多标签分类结果,能有效捕获图像空间信息和上下文信息,提升标签分类关联性。
- 编码方式:把复杂的12类混合分类简化为6类,用独热码处理多标签情况,便于模型学习和分类。
- 评估指标:提供PR曲线和F1分数曲线等性能指标,可直观评估模型性能。
- 预测功能:输入叶片图像,能输出其属于不同标签分类的原始概率和经过sigmoid函数后的概率,从而判断叶片所属类别。
安装使用步骤
环境配置
本项目所需的Python环境无特殊要求,可在自己的torch环境中配置。在项目根目录下,执行以下命令安装依赖:
bash
pip install -r requirements.txt
数据集和权重下载
数据集合并
由于独热编码对数据集进行了改动,需将仓库中的plant_dataset
文件夹和下载得到的plant_dataset
文件夹进行合并,合并后格式如下:
├─train
| ├─images
| ├─train_label.csv
| ├─labels.csv
├─val
| ├─images
| ├─val_label.csv
| ├─labels.csv
├─test
| ├─images
| ├─test_label.csv
| ├─labels.csv
训练
在命令行中执行以下命令开始训练:
bash
python train.py --data-path <path of plant_dataset> --weights <path of your weight>
若不设置weights
,则从随机参数开始训练。
预测
使用以下命令进行预测:
bash
python predict.py --img-path <path of plant_dataset> --weights <path of your weight>
终端输出两行结果,分别代表属于不同标签分类的原始概率和经过sigmoid函数后的概率,可据此判断叶片所属类别。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】