项目简介
本项目是一个基于Python和TensorFlow的图像分类系统,借助TF - Slim库构建和训练深度学习模型。支持自定义图片数据的准备、转换、模型训练、评估,以及模型的导出和预测。
项目的主要特性和功能
- 数据准备:支持自定义图片数据准备,并将其转换为TF - Record格式方便后续处理。
- 模型训练:使用MobileNet模型进行图像分类训练,可自定义训练参数,如学习率、批量大小等。
- 模型评估:具备模型评估功能,用于验证训练后模型的性能。
- 模型导出:支持将训练好的模型导出为可用于预测的格式。
- 图像预测:提供单张图片的预测功能,可使用导出的模型进行实时预测。
安装使用步骤
环境准备
- 确保已安装Python 3.6和TensorFlow GPU版本1.6.0。
- 安装其他依赖库,如TF - Slim。
数据准备
- 将自定义图片数据放入
data_prepare/pic/train
和data_prepare/pic/validation
目录中。 - 在
data_prepare/
目录下执行data_convert.py
脚本,将图片数据转换为TF - Record格式。
模型训练
- 在
slim/
目录下执行train_image_classifier.py
脚本,进行模型训练。 - 根据需要调整训练参数,如学习率、批量大小等。
模型评估
在slim/
目录下执行eval_image_classifier.py
脚本,进行模型评估。
模型导出
- 在
slim/
目录下执行export_inference_graph.py
脚本,导出训练好的模型。 - 在项目根目录下执行
freeze_graph.py
脚本,将模型冻结为可用于预测的格式。
图像预测
在项目根目录下执行classify_image_test.py
脚本,对单张图片进行预测。
示例
```bash python data_prepare/data_convert.py -t pic/ --train-shards 2 --validation-shards 2 --num-threads 2 --dataset-name satellite
python slim/train_image_classifier.py --train_dir=flowers/train_log --dataset_name=flowers --train_image_size=299 --dataset_split_name=train --dataset_dir=data --model_name="mobilenet_v2_140" --checkpoint_path=model/mobilenet_v2_1.4_224.ckpt --checkpoint_exclude_scopes=MobilenetV2/Logits,MobilenetV2/AuxLogits --trainable_scopes=MobilenetV2/Logits,MobilenetV2/AuxLogits --max_number_of_steps=20000 --batch_size=16 --learning_rate=0.001 --learning_rate_decay_type=fixed --log_every_n_steps=10 --optimizer=rmsprop --weight_decay=0.00004 --label_smoothing=0.1 --num_clones=1 --num_epochs_per_decay=2.5 --moving_average_decay=0.9999 --learning_rate_decay_factor=0.98 --preprocessing_name="inception_v2"
python slim/eval_image_classifier.py --checkpoint_path=flowers/train_log --eval_dir=flowers/eval_log --dataset_name=flowers --dataset_split_name=validation --dataset_dir=data --model_name="mobilenet_v2_140" --batch_size=32 --num_preprocessing_threads=2 --eval_image_size=299
python slim/export_inference_graph.py --alsologtostderr --model_name="mobilenet_v2_140" --image_size=299 --output_file=flowers/export/mobilenet_v2_140_inf_graph.pb --dataset_name flowers python freeze_graph.py --input_graph slim/flowers/export/mobilenet_v2_140_inf_graph.pb --input_checkpoint slim/flowers/train_log/model.ckpt-20000 --input_binary true --output_node_names MobilenetV2/Predictions/Reshape_1 --output_graph slim/flowers/export/frozen_graph.pb
python classify_image_test.py --model_path slim/flowers/export/frozen_graph.pb --label_path data_prepare/pic/label.txt --image_file test_image.jpg ```
注意事项
- 请确保按照TensorFlow和所需库的正确版本进行安装。
- 根据实际需求选择合适的模型和预处理函数。
- 模型的训练和评估过程可能需要大量的计算资源,请确保具有足够的硬件支持。
- 在实际使用中,请根据具体任务和数据集调整模型配置和预处理参数。
下载地址
点击下载 【提取码: 4003】【解压密码: www.makuang.net】