项目简介
本项目是基于GPT2的中文闲聊机器人,借助自然语言处理(NLP)技术达成与用户的自然交互。开发采用了Hugging Face的transformers
库以及PyTorch框架,同时参考了GPT2 - Chinese和DialoGPT等相关论文与开源项目。
项目的主要特性和功能
- 模型训练:运用预训练的GPT2模型,在中文闲聊语料上进行微调训练,增强模型的对话生成能力。
- 多轮对话:模型可处理多轮对话,每一轮均能生成对应回复。
- 交互式对话:通过命令行界面和用户交互,用户输入问题或主题,模型生成回复。
- 模型评估与可视化:提供对话长度的统计分析和可视化工具,便于了解对话多样性。
- 模型分享:提供预训练的模型供用户使用和评估。
安装使用步骤
环境准备
确保安装了Python 3.6、transformers
库(4.2.0版本或更高)和pytorch
库(1.7.0版本或更高)。
数据准备
准备中文闲聊语料,并按照项目的数据格式组织,语料格式为每段闲聊之间间隔一行。
数据预处理
运行preprocess.py
脚本,对数据进行tokenize并序列化保存到train.pkl
文件中。
bash
python preprocess.py --train_path data/train.txt --save_path data/train.pkl
模型训练
运行train.py
脚本,使用预处理后的数据进行模型训练。
bash
python train.py --epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl
人机交互
运行interact.py
脚本,使用训练好的模型进行人机交互。
bash
python interact.py --no_cuda --model_path path_to_your_model --max_history_len 3
若使用GPU,可去掉--no_cuda
参数,并指定GPU设备。
bash
python interact.py --model_path path_to_your_model --device 0 --max_history_len 3
模型下载
项目提供的预训练模型可从以下链接下载: - 百度网盘【提取码:ju6m】 - GoogleDrive
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】