项目简介
本项目是为解决百度网盘AI大赛手写文字擦除任务而设计的方案。借助PaddlePaddle深度学习框架,构建并训练高效的神经网络模型,以擦除扫描文档中的手写文字,提升文档处理质量和用户体验。
项目的主要特性和功能
- 数据处理:分开训练推理官方提供的文档类和试卷类图片,为文档类数据重新生成高质量mask,运用腐蚀膨胀、降低阈值、分别计算三通道插值等方法。
- 模型设计:有分类模型(resnet18网络结构)、试卷模型和文档模型(EraseNet网络结构,输入用swin transformer提取全局特征)。文档模型各分支的decoder阶段叠加encoder多尺度特征,对mask用l2损失,对生成图片用重建损失训练。
- 训练优化:对文档类训练集样本padding,加入小角度旋转和水平翻转增强模型鲁棒性。针对分类模型样本不均衡问题,每个epoch调整训练数据。
- 测试策略:测试时图片输入size为512x512,分块推理再拼接;先经分类模型分类,再用对应类别模型推理;去掉mask,采用refinement分支输出图片;使用翻转镜像增强;加入padding,采用交错分块裁剪;采用多种策略优化推理时间。
安装使用步骤
- 环境准备:确保已安装PaddlePaddle框架和相应的依赖库。
- 数据准备:将数据集按指定文件路径放置,classone中是本届比赛提供的文档类数据集,dehw_train_dataset是上一届比赛的试卷类数据集。重新生成mask,将data_root和data_path改成对应的数据路径,运行
python generate_mask5.py
。 - 模型训练:指定数据集路径
dataRoot
和预训练模型路径pretrained
,训练分类网络运行python trainV.py
,训练手写文字擦除网络运行python trainNewMaskAugSchedule.py
。 - 模型推理:进入Final目录,运行
python predict.py src_image/ save_image
进行推理,最终将Final打包提交即可得到线上成绩。 - 预训练模型:可从https://aistudio.baidu.com/bj-cpu-01/user/446178/4503672/doc/tree/work/Final/model/STE_str_best.pdparams 下载预训练模型。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】