TensorFlow(TF-Slim)で簡単に転移学習を試す


Pocket

以前TensorFlow r0.12で追加された”TensorBoard: Embedding Visualization”を使って、きゅうり画像に対する畳み込み層で抽出した特徴量を可視化してみたのですが、Inception-v3も割ときゅうり仕分けのための特徴量を抽出できてるなと。

・自分で作った4層のCNN
ev_my

・Inception-v3(pool3の出力)
ev_inspv3

Inception-v3が抽出している特徴量でも、きゅうりの等級を割と区別できてるように思える。

ということで、転移学習を試してみたくなったので、そのやり方を調べてみました。

TensorFlow/modelsで用意されてた!?

手始めに、学習済みモデルがないかなと探していたのですが、TensorFlow/modelsリポジトリにありました。

Pre-trained Models
ILSVRC-2012-CLSのデータセットで学習させた、既存ニューラルネットワークのcheckpointファイルが公開されています。

公開モデル一覧
図:公開モデル一覧

 

しかも簡単に試すことができる!?

しかも、公開されているモデルで転移学習を試してみるためのコードも用意されていたので、早速試してみました。

1.Checkpointファイルをダウンロード

まずは、Pre-trained Modelsのリンクから、Checkpointファイルをダンロードして、/tmp/ckptフォルダに解凍して置いておきます。
今回は、vgg_16.tar.gzでやってみました。

2.リポジトリから取得

次に、modelsリポジトリを取ってきます。

git clone https://github.com/tensorflow/models.git
cd models/slim

3.データセットをダウンロード

model/slimでは、Cifar10,MNIST,ImageNet,Flowersといったデータセットが既に用意されています。
Preparing the datasets
今回は、Flowersデータセットを使うことにしました。
嬉しいことに、データセットの取得も下記コマンド1発で出来るという。

python download_and_convert_data.py --dataset_name=flowers --dataset_dir="/tmp/flowers"

4.転移学習をやってみる

転移学習も下記コマンドで試してみることができます。

python train_image_classifier.py \
  --train_dir=data \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=/tmp/flowers \
  --model_name=vgg_16 \
  --checkpoint_path=/tmp/ckpt/vgg_16.ckpt \
  --checkpoint_exclude_scopes=vgg_16/fc8 \
  --trainable_scopes=vgg_16/fc8 \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=100 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

長いですが、ハイパーパラメータを色々設定出来るようになっています。
転移学習時のポイントは、次の2点です。
–checkpoint_exclude_scopes:checkpointファイルからレストアしない
–trainable_scopes:訓練対象のVariableをスコープ単位で指定

上記の例では、VGG16の最後の全結合層(出力層)fc8を再学習しています。
VGG16構成
図:VGG16のニューラルネットワーク構成(fc8は最後の層)
VGG16構成(fc8)
図:fc8の内部

fc8内のweightsとbiasesが再学習対象になります。
複数層を指定したい場合は、コンマで区切って指定すればOKです。
–checkpoint_exclude_scopes=vgg_16/fc8,vgg_16/fc7,vgg_15/fc6
–trainable_scopes=vgg_16/fc8,vgg_16/fc7,vgg_15/fc6
みたいに。

 
1000ステップ回してみましたが、5時間ぐらいかかったでしょうか。
vgg_training
まだ全然lossが収束してないかんじです…というか、再学習対象をfc8だけにしたのが間違ってるきがします。
※全結合層すべて対象にしてやり直すと時間かかるから今回はパスorz

5.テストデータで検証する

最後に、テストデータで検証してみます。
これも下記コマンドで出来ます。簡単!!!

python eval_image_classifier.py \
  --alsologtostderr \
  -- checkpoint_path=data/model.ckpt-1000 \
  --dataset_dir=/tmp/flowers \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --model_name=vgg_16
結果

2500枚の花画像を5クラス(daisy,dandelion,roses,sunflowers,tulips)に識別するタスクにおいて、85%の正答率とい結果が得られました。

 

あとがき

次はきゅうり画像でもやってみます。
転移学習についてはもうちょっと詳しく勉強してみたいなと思っている今日このごろ。

下記論文だと、マンゴーを認識するフィルタは、リンゴも上手く認識できるみたいな結果が出てておもしろい。
しかも学習画像100枚程度で80%の認識率が出てるし!と思ったけど、転移学習しない場合でも同じ程度のパフォーマンスが出てた。
Deep Fruit Detection in Orchards

画像認識ってイデアを探る研究なのかな。

ではでは〜。

Leave a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です