Study Blog

自分の興味関心が向いたものを、好きな時好きなだけ気分で勉強したことを記すブログ

Raspberry Pi4 画像認識 ~⑥独自画像学習 独自画像の学習~

アノテーションCSVを変更し、Tensorflow model makerの物体検出で使用できる形に変更しました。
今回は、この変更したCSVを用いてモデルを学習させます。

1.下準備

基本的に、Tensorflowの公式ドキュメント
Object Detection with TensorFlow Lite Model Maker
に沿って行います。

!pip install -q --use-deprecated=legacy-resolver tflite-model-maker
!pip install -q pycocotools
!pip install -q opencv-python-headless==4.1.2.30

google colabにtflite model maker / pycocotools / opencvをインストール

import numpy as np
from pathlib import Path

from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

ライブラリインポート

ANNO_CSV = Path("※アノテーションdataset.csvのパス")
SAVE_PATH = Path("※今回制作するmodelの保存パス")

パスの指定

spec = model_spec.get('efficientdet_lite0')

ここで、物体検出モデルを指定します。ここでは、EfficientNetというニューラルネットワークモデルをtfliteで利用できるように変換したものを指定しています。
公式ではefficientdet_lite0~4をサポートしているようで、数字が大きくなるほど精度が高くなり、推論速度が遅くなるようです。

train_data, validation_data, test_data = object_detector.DataLoader.from_csv(ANNO_CSV)

前回修正したCSVファイルを使って、訓練データ、評価データ、テストデータを作成

model = object_detector.create(train_data, model_spec=spec, batch_size=8, train_whole_model=True, validation_data=validation_data)

ここで、学習モデルを作成します。
画像枚数によりますが、じゃんけん画像の場合20~30分程度時間がかかります。
メモリ不足になる場合は、「batch_size=」の部分の数字を小さくするとうまくいこことが多いです。

batch_size=8は、8枚の画像を1セットとして学習させることをさしています。
batch_sizeを大きくすれば、まとめてたくさんの画像を学習でき、batch_sizeを小さくすれば、少量ずつ画像を学習させることができます。
画像が多い場合、batch_sizeを大きくすることで学習時間を短縮することもできます。
適宜変更して対応してみてください。

Epoch 1/50
16/16 [==============================] - 73s 3s/step - det_loss: 1.6946 - cls_loss: 1.1217 - box_loss: 0.0115 - reg_l2_loss: 0.0633 - loss: 1.7579 - learning_rate: 0.0090 - gradient_norm: 1.3597 - val_det_loss: 1.5367 - val_cls_loss: 1.0551 - val_box_loss: 0.0096 - val_reg_l2_loss: 0.0633 - val_loss: 1.6001
Epoch 2/50
16/16 [==============================] - 40s 2s/step - det_loss: 1.3411 - cls_loss: 0.9364 - box_loss: 0.0081 - reg_l2_loss: 0.0633 - loss: 1.4044 - learning_rate: 0.0100 - gradient_norm: 1.7222 - val_det_loss: 1.1773 - val_cls_loss: 0.8345 - val_box_loss: 0.0069 - val_reg_l2_loss: 0.0633 - val_loss: 1.2406
Epoch 3/50
16/16 [==============================] - 39s 2s/step - det_loss: 0.9503 - cls_loss: 0.6632 - box_loss: 0.0057 - reg_l2_loss: 0.0634 - loss: 1.0137 - learning_rate: 0.0099 - gradient_norm: 1.8717 - val_det_loss: 0.9387 - val_cls_loss: 0.5932 - val_box_loss: 0.0069 - val_reg_l2_loss: 0.0634 - val_loss: 1.0021
Epoch 4/50
16/16 [==============================] - 40s 2s/step - det_loss: 0.7808 - cls_loss: 0.5625 - box_loss: 0.0044 - reg_l2_loss: 0.0634 - loss: 0.8442 - learning_rate: 0.0099 - gradient_norm: 1.9342 - val_det_loss: 0.9141 - val_cls_loss: 0.4919 - val_box_loss: 0.0084 - val_reg_l2_loss: 0.0634 - val_loss: 0.9775
Epoch 5/50
16/16 [==============================] - 46s 3s/step - det_loss: 0.6327 - cls_loss: 0.4590 - box_loss: 0.0035 - reg_l2_loss: 0.0634 - loss: 0.6961 - learning_rate: 0.0098 - gradient_norm: 2.4829 - val_det_loss: 0.6877 - val_cls_loss: 0.4158 - val_box_loss: 0.0054 - val_reg_l2_loss: 0.0634 - val_loss: 0.7511
Epoch 6/50
16/16 [==============================] - 40s 2s/step - det_loss: 0.5148 - cls_loss: 0.3654 - box_loss: 0.0030 - reg_l2_loss: 0.0634 - loss: 0.5782 - learning_rate: 0.0097 - gradient_norm: 2.0064 - val_det_loss: 0.6624 - val_cls_loss: 0.3615 - val_box_loss: 0.0060 - val_reg_l2_loss: 0.0634 - val_loss: 0.7258
Epoch 7/50
16/16 [==============================] - 39s 2s/step - det_loss: 0.4151 - cls_loss: 0.2857 - box_loss: 0.0026 - reg_l2_loss: 0.0634 - loss: 0.4786 - learning_rate: 0.0096 - gradient_norm: 1.8357 - val_det_loss: 0.5232 - val_cls_loss: 0.3218 - val_box_loss: 0.0040 - val_reg_l2_loss: 0.0634 - val_loss: 0.5866
Epoch 8/50
16/16 [==============================] - 40s 3s/step - det_loss: 0.4025 - cls_loss: 0.2825 - box_loss: 0.0024 - reg_l2_loss: 0.0635 - loss: 0.4659 - learning_rate: 0.0094 - gradient_norm: 2.3049 - val_det_loss: 0.6151 - val_cls_loss: 0.3453 - val_box_loss: 0.0054 - val_reg_l2_loss: 0.0635 - val_loss: 0.6786
Epoch 9/50
16/16 [==============================] - 40s 3s/step - det_loss: 0.3461 - cls_loss: 0.2444 - box_loss: 0.0020 - reg_l2_loss: 0.0635 - loss: 0.4096 - learning_rate: 0.0093 - gradient_norm: 2.0001 - val_det_loss: 0.4147 - val_cls_loss: 0.2451 - val_box_loss: 0.0034 - val_reg_l2_loss: 0.0635 - val_loss: 0.4782
Epoch 10/50
16/16 [==============================] - 41s 3s/step - det_loss: 0.3328 - cls_loss: 0.2338 - box_loss: 0.0020 - reg_l2_loss: 0.0635 - loss: 0.3963 - learning_rate: 0.0091 - gradient_norm: 1.9760 - val_det_loss: 0.3180 - val_cls_loss: 0.2151 - val_box_loss: 0.0021 - val_reg_l2_loss: 0.0635 - val_loss: 0.3815

実際の学習過程が出力されていて、学習が正常に行われているかを確認することができます。
val_lossの数値が徐々に小さくなっていることが確認できれば大方うまくいってると思って大丈夫です。

うまく学習できていない場合、val_loss部分が初めから「0.000」で変化しない場合が多いと思います。
その場合は、アノテーションCSV部分で誤りがあることが大半だと思います。
アノテーションCSVのx/yの数字を見直したり、画像パスを見直したりしましょう。

model.evaluate(test_data)
1/1 [==============================] - 6s 6s/step

{'AP': 0.91557753,
 'AP50': 1.0,
 'AP75': 1.0,
 'AP_/Scissors': 0.9158746,
 'AP_/label': -1.0,
 'AP_/paper': 0.93646866,
 'AP_/rock': 0.89438945,
 'APl': 0.91557753,
 'APm': -1.0,
 'APs': -1.0,
 'ARl': 0.9261905,
 'ARm': -1.0,
 'ARmax1': 0.91952384,
 'ARmax10': 0.9261905,
 'ARmax100': 0.9261905,
 'ARs': -1.0}

モデルを評価して確認

model.export(export_dir=SAVE_PATH)
model.evaluate_tflite(str(SAVE_PATH) + '/model.tflite', test_data)
18/18 [==============================] - 38s 2s/step

{'AP': 0.9018482,
 'AP50': 1.0,
 'AP75': 1.0,
 'AP_/Scissors': 0.87204623,
 'AP_/label': -1.0,
 'AP_/paper': 0.93646866,
 'AP_/rock': 0.8970297,
 'APl': 0.9018482,
 'APm': -1.0,
 'APs': -1.0,
 'ARl': 0.91190475,
 'ARm': -1.0,
 'ARmax1': 0.91190475,
 'ARmax10': 0.91190475,
 'ARmax100': 0.91190475,
 'ARs': -1.0}

学習モデルをgoogle driveに保存して、テストデータを使ってtfliteモデルを最終評価します。

          • 関連記事-----

melostark.hatenablog.com
melostark.hatenablog.com
melostark.hatenablog.com
melostark.hatenablog.com
melostark.hatenablog.com
melostark.hatenablog.com
melostark.hatenablog.com
melostark.hatenablog.com