以前TensorFlow r0.12で追加された”TensorBoard: Embedding Visualization”を使って、きゅうり画像に対する畳み込み層で抽出した特徴量を可視化してみたのですが、Inception-v3も割ときゅうり仕分けのための特徴量を抽出できてるなと。
Inception-v3が抽出している特徴量でも、きゅうりの等級を割と区別できてるように思える。
ということで、転移学習を試してみたくなったので、そのやり方を調べてみました。
TensorFlow/modelsで用意されてた!?
手始めに、学習済みモデルがないかなと探していたのですが、TensorFlow/modelsリポジトリにありました。
Pre-trained Models
ILSVRC-2012-CLSのデータセットで学習させた、既存ニューラルネットワークのcheckpointファイルが公開されています。
しかも簡単に試すことができる!?
しかも、公開されているモデルで転移学習を試してみるためのコードも用意されていたので、早速試してみました。
1.Checkpointファイルをダウンロード
まずは、Pre-trained Modelsのリンクから、Checkpointファイルをダンロードして、/tmp/ckptフォルダに解凍して置いておきます。
今回は、vgg_16.tar.gzでやってみました。
2.リポジトリから取得
次に、modelsリポジトリを取ってきます。
git clone https://github.com/tensorflow/models.git cd models/slim
3.データセットをダウンロード
model/slimでは、Cifar10,MNIST,ImageNet,Flowersといったデータセットが既に用意されています。
Preparing the datasets
今回は、Flowersデータセットを使うことにしました。
嬉しいことに、データセットの取得も下記コマンド1発で出来るという。
python download_and_convert_data.py --dataset_name=flowers --dataset_dir="/tmp/flowers"
4.転移学習をやってみる
転移学習も下記コマンドで試してみることができます。
python train_image_classifier.py \ --train_dir=data \ --dataset_name=flowers \ --dataset_split_name=train \ --dataset_dir=/tmp/flowers \ --model_name=vgg_16 \ --checkpoint_path=/tmp/ckpt/vgg_16.ckpt \ --checkpoint_exclude_scopes=vgg_16/fc8 \ --trainable_scopes=vgg_16/fc8 \ --max_number_of_steps=1000 \ --batch_size=32 \ --learning_rate=0.01 \ --learning_rate_decay_type=fixed \ --save_interval_secs=60 \ --save_summaries_secs=60 \ --log_every_n_steps=100 \ --optimizer=rmsprop \ --weight_decay=0.00004
長いですが、ハイパーパラメータを色々設定出来るようになっています。
転移学習時のポイントは、次の2点です。
–checkpoint_exclude_scopes:checkpointファイルからレストアしない
–trainable_scopes:訓練対象のVariableをスコープ単位で指定
上記の例では、VGG16の最後の全結合層(出力層)fc8を再学習しています。
図:VGG16のニューラルネットワーク構成(fc8は最後の層)
図:fc8の内部
fc8内のweightsとbiasesが再学習対象になります。
複数層を指定したい場合は、コンマで区切って指定すればOKです。
–checkpoint_exclude_scopes=vgg_16/fc8,vgg_16/fc7,vgg_15/fc6
–trainable_scopes=vgg_16/fc8,vgg_16/fc7,vgg_15/fc6
みたいに。
1000ステップ回してみましたが、5時間ぐらいかかったでしょうか。
まだ全然lossが収束してないかんじです…というか、再学習対象をfc8だけにしたのが間違ってるきがします。
※全結合層すべて対象にしてやり直すと時間かかるから今回はパスorz
5.テストデータで検証する
最後に、テストデータで検証してみます。
これも下記コマンドで出来ます。簡単!!!
python eval_image_classifier.py \ --alsologtostderr \ -- checkpoint_path=data/model.ckpt-1000 \ --dataset_dir=/tmp/flowers \ --dataset_name=flowers \ --dataset_split_name=validation \ --model_name=vgg_16
結果
2500枚の花画像を5クラス(daisy,dandelion,roses,sunflowers,tulips)に識別するタスクにおいて、85%の正答率とい結果が得られました。
あとがき
次はきゅうり画像でもやってみます。
転移学習についてはもうちょっと詳しく勉強してみたいなと思っている今日このごろ。
下記論文だと、マンゴーを認識するフィルタは、リンゴも上手く認識できるみたいな結果が出てておもしろい。
しかも学習画像100枚程度で80%の認識率が出てるし!と思ったけど、転移学習しない場合でも同じ程度のパフォーマンスが出てた。
Deep Fruit Detection in Orchards
画像認識ってイデアを探る研究なのかな。
ではでは〜。