项目简介
本项目是基于PyTorch框架的文本分类系统,用于处理文本分类任务。系统涵盖数据预处理、模型训练和预测三大模块,支持从预训练词向量文件构建词嵌入,采用CNN模型开展文本分类工作。
项目的主要特性和功能
- 数据预处理
- 构建词表和标签表。
- 从预训练文件构建词嵌入,对未出现的词随机初始化。
- 对数据顺序编号,实现高效加载。
- 统计句子长度并输出百分位值。
- 模型训练
- 运用CNN模型进行文本分类训练。
- 支持自定义模型和训练参数。
- 支持GPU加速训练。
- 预测
- 利用训练好的模型对新句子分类标注。
- 支持批量处理和GPU加速。
安装使用步骤
安装依赖
确保已安装以下Python库: - gensim==2.3.0 - numpy==1.13.1 - torch==0.2.0.post3 - torchvision==0.1.9
使用pip安装:
bash
pip install gensim==2.3.0 numpy==1.13.1 torch==0.2.0.post3 torchvision==0.1.9
使用步骤
- 数据预处理
- 处理训练数据:
bash python3 preprocessing.py -l --pd ./data/train.txt --ri ./data/train_idx/ --rv ./res/voc/ --re ./res/embed/ --pe ./path_to_embed_file
- 处理测试数据:
bash python3 preprocessing.py --pd ./data/test.txt --ri ./data/test_idx/
- 处理训练数据:
- 模型训练
使用预处理生成的数据训练模型:
bash CUDA_VISIBLE_DEVICES=0,1 python3 train.py --nc 2 --ml 40 --fs 3,4,5 --fn 400,300,200 --wd 64 --bs 256 -g
- 模型预测
使用训练好的模型对新句子分类标注:
bash CUDA_VISIBLE_DEVICES=0,1 python3 test.py --bs 256 -g --pr ./result.txt
注意事项
- 确保数据格式正确,按需调整命令行参数。
- 遇到问题可参考项目帮助文档和错误提示排查。
- 本项目使用PyTorch框架,确保环境已正确安装PyTorch及其依赖库。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】