项目简介
本项目是一个运用Python和PyTorch构建的分层模型训练应用,借助滑窗策略对预训练模型进行fine - tuning,以提升模型性能。同时,项目提供了基于Flask的Web服务,用户可上传文本数据并获取实体识别结果。
项目的主要特性和功能
- 滑窗策略:采用滑窗策略对预训练模型进行fine - tuning,提升模型性能。
- Web服务:提供基于Flask的Web服务,用户通过HTTP请求发送文本数据,获取实体识别结果。
- 实体识别:模型可识别文本中的实体,如货币资金、合并财务报表附注等。
- 对抗训练:支持在模型训练中使用对抗训练技术,增强模型鲁棒性。
- 日志和监控:具备日志记录功能,用于监控训练过程和结果。
安装使用步骤
环境准备
- 创建并激活虚拟环境:
bash conda create -n yourname python==3.10.0 conda activate yourname
- 安装依赖:
bash conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch conda install transformers==4.29.2 pip install rich==12.5.1 flask gevent
数据准备
- 制造样本并导入Label_Studio平台进行标注。
- 导出标注结果,模仿项目中的
data/财务附注定位
目录结构,创建新项目并拷贝preprocess.py
。
训练模型
- 执行
preprocess.py
进行数据预处理。 - 修改
config.py
中的data_dir
参数。 - 修改
train.py
脚本,新增一个if分支,然后执行train.py
进行模型训练。
启动Web服务
- 修改
server.py
中的model_name
变量。 - 运行
server.py
启动Web服务。
测试Web服务
使用test.py
脚本测试Web服务的性能和准确性。
查询服务
通过HTTP请求访问Web服务,上传文本数据并获取实体识别结果。
注意事项
- 模型优化:通过优化样本和对抗训练可进一步提高模型性能。
- 数据格式:Web服务要求输入数据为JSON格式的字符串列表。
- 服务地址:Web服务默认监听10.17.107.66的10089端口。
- 模型训练:项目的模型训练依赖特定的数据结构和格式,需按规定的预处理步骤进行。
更新日志
- 2024 - 01 - 18:修改了
server.py
和preprocess.py
,补充提供了样例json文件。 - 2024 - 03 - 18:修改了
server.py
,提供了基于Docker的升级说明。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】