项目简介
本项目是基于PyTorch实现的模型蒸馏联邦学习框架(FedMD),用于解决多个参与者私有数据稀缺问题。每个参与者只需少量私有数据,可先在公共数据集上预训练,再在私有数据集上微调,最终基于公共数据集的模型输出实现模型知识的共享与集成。
项目的主要特性和功能
- 模型蒸馏与联邦学习融合:通过知识蒸馏技术,各参与者模型依据公共数据集输出的类别分数通信,实现模型知识的共享和集成。
- 迁移学习助力:先在公共数据集上充分训练模型,再在私有数据集上微调,有效提升模型性能。
- 多用户协同支持:支持多个用户(每个用户私有数据量少)同时参与模型训练。
- 训练过程可视化:提供训练损失和准确率的可视化,方便调试和性能分析。
安装使用步骤
环境准备
确保已安装Python 3.7.6、torch 1.6.0和torchvision 0.7.0。
数据准备
准备MNIST和FEMNIST数据集,以及对应的私有数据集。
模型训练
- 运行
pretrained_public_mnist_initial.py
对模型进行预训练。 - 运行
private_model_femnist_balanced.py
在私有数据集上训练模型。
模型评估
使用pretrained_public_mnist_Accuracy.py
、private_model_femnist_Accuracy.py
和collaborative_model_femnist_Accuracy.py
等脚本评估模型性能。
模型集成
运行collaborative_train_balanced_mnist.py
和Collaborative_step.py
进行模型集成和协作式训练。
结果可视化
利用matplotlib库绘制训练损失和准确率曲线,分析模型性能。
实际使用时需根据具体环境和需求调整上述步骤。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】