项目简介
本项目是基于联邦学习和生成对抗网络(GAN)的神经网络训练系统,名为 GanSFL。其主要目的是处理分布式数据集上的非独立同分布(Non-IID)特征。项目运用了来自不同域的数据集,如 MNIST、SVHN、USPS、SynthDigits、MNIST_M,并采用联邦学习方式在客户端之间共享模型更新。
项目的主要特性和功能
- 支持在多个客户端之间进行模型训练,各客户端持有不同的数据集。
- 引入 GAN 的鉴别器,提升模型在非独立同分布数据上的性能。
- 对 ResNet18 模型在不同点切割,实验其对训练效果的影响。
- 研究不同数据集大小对模型训练效果的作用。
- 实验不同梯度权重对模型训练的影响。
- 研究以不同客户端输出的特征图作为目标域对训练效果的影响。
安装使用步骤
- 安装依赖:
- 安装
pytorch
和torchvision
。 - 安装
wandb
用于可视化和参数管理。
- 安装
- 配置环境:
- 运行
wandb login
登录 wandb 账号,或运行wandb offline
关闭 wandb。
- 运行
- 数据准备:
- 新建
data
文件夹,将下载的数据集放入其中。 - 在代码目录中新建
images/GanSFL
和images/SFL
文件夹,用于保存训练过程中输出的特征图。
- 新建
- 配置参数:
- 修改
config-defaults.yaml
文件中的配置参数,以适应不同的实验需求。
- 修改
- 运行代码:
- 按照 README 中的说明运行代码,进行模型训练和实验。
- 结果分析:
- 通过 wandb 查看详细的实验对比结果,分析不同参数和条件下的模型性能。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】