littlebot
Published on 2025-04-08 / 1 Visits
0

【源码】基于Python和PyTorch的分层模型训练与Web服务

项目简介

本项目是一个运用Python和PyTorch构建的分层模型训练应用,借助滑窗策略对预训练模型进行fine - tuning,以提升模型性能。同时,项目提供了基于Flask的Web服务,用户可上传文本数据并获取实体识别结果。

项目的主要特性和功能

  1. 滑窗策略:采用滑窗策略对预训练模型进行fine - tuning,提升模型性能。
  2. Web服务:提供基于Flask的Web服务,用户通过HTTP请求发送文本数据,获取实体识别结果。
  3. 实体识别:模型可识别文本中的实体,如货币资金、合并财务报表附注等。
  4. 对抗训练:支持在模型训练中使用对抗训练技术,增强模型鲁棒性。
  5. 日志和监控:具备日志记录功能,用于监控训练过程和结果。

安装使用步骤

环境准备

  • 创建并激活虚拟环境: bash conda create -n yourname python==3.10.0 conda activate yourname
  • 安装依赖: bash conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch conda install transformers==4.29.2 pip install rich==12.5.1 flask gevent

数据准备

  • 制造样本并导入Label_Studio平台进行标注。
  • 导出标注结果,模仿项目中的data/财务附注定位目录结构,创建新项目并拷贝preprocess.py

训练模型

  • 执行preprocess.py进行数据预处理。
  • 修改config.py中的data_dir参数。
  • 修改train.py脚本,新增一个if分支,然后执行train.py进行模型训练。

启动Web服务

  • 修改server.py中的model_name变量。
  • 运行server.py启动Web服务。

测试Web服务

使用test.py脚本测试Web服务的性能和准确性。

查询服务

通过HTTP请求访问Web服务,上传文本数据并获取实体识别结果。

注意事项

  1. 模型优化:通过优化样本和对抗训练可进一步提高模型性能。
  2. 数据格式:Web服务要求输入数据为JSON格式的字符串列表。
  3. 服务地址:Web服务默认监听10.17.107.66的10089端口。
  4. 模型训练:项目的模型训练依赖特定的数据结构和格式,需按规定的预处理步骤进行。

更新日志

  • 2024 - 01 - 18:修改了server.pypreprocess.py,补充提供了样例json文件。
  • 2024 - 03 - 18:修改了server.py,提供了基于Docker的升级说明。

下载地址

点击下载 【提取码: 4003】【解压密码: www.makuang.net】