项目简介
本项目基于Python和TensorFlow框架,围绕快手自主开发的集成生存分析软件KwaiSurvival开展。该软件能让使用者在Python编程环境下,高效利用生存分析模型进行大规模数据分析。项目对多个深度学习生存分析模型进行测试,具备模型训练、预测、评估等功能,同时探索了数据处理、模型保存等方面。
项目的主要特性和功能
- 多模型集成:集成DeepSurv、DeepHit和DeepMultiTasks三个深度学习生存分析模型,满足不同生存分析场景需求。
- 模型评估:支持用Harrell’ concordance index(C - index)评估模型预测结果准确性,衡量模型区分能力。
- 数据处理:可借助其他开源项目的数据进行测试,有一定数据适配能力。
- 模型保存:为自定义的loss / Transform添加get_config函数,解决模型保存报错问题。
- 训练与预测:能对模型进行训练,对新数据进行预测并得到生存函数预测结果。
- 可视化:可绘制KM曲线,直观展示生存函数随时间的变化情况。
安装使用步骤
环境准备
假设用户已下载本项目的源码文件,需确保系统已安装Python环境,并按如下命令安装指定版本的TensorFlow:
pip install --pre tensorflow==2.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
若在Windows 10系统上升级tf - cpu版本报错ImportError:DLL load failed:找不到指定模块
,需下载并安装Visual Studio 2015、2017和2019。
代码运行
- 模型训练与保存(以DeepSurv为例): ```python import pandas as pd from DeepSurv import DeepSurv
df = pd.read_csv('test/KwaiSurvival/demo/example_data.csv') label = 'Time' event = 'Event'
ds = DeepSurv(df, label, event)
epochs = 100 ds.train(epochs)
ds.model.save('test/KwaiSurvival/path_to_my_model_DeepSurv.h5')
2. **模型预测**:
python
scores = ds.predict_score(ds.X)
3. **模型评估**:
python
ds.concordance_eval(X=None, event_times=None, event_observed=None)
ds.concordance_eval(X=ds.X, event_times=ds.label, event_observed=ds.event)
ds.concordance_eval(X=df.iloc[:4, :4].to_numpy(), event_times=df[label][:4], event_observed=df[event][:4])
4. **生存函数预测与可视化**:
python
survival_function = ds.predict_survival_func(ds.X)
ds.plot_survival_func(ds.X, 1000) ```
需注意,项目目前处于实验版,部分功能可能存在问题,如DeepMultiTasks模型训练时可能出现loss为nan的情况,使用时可根据实际情况进行调试和优化。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】