PCで何時間もかけて学習したニューラルネットワークをAndroidやRaspberryPiで使いたい!って時にはProtocol Buffers形式でグラフを保存してあげる必要があります。
A Tool Developer’s Guide to TensorFlow Model Files
ということで、今回はその方法についてです。
checkpoint fileの様にsaveコマンド一発で保存できたらいいのですが、どうやらそんなお手軽コマンドはないようです。
…でネットで調べてみると下記の方法でできるようなので、早速試してみました。
【環境】
Ubuntu14.04
python2.7
Tensorflow 0.9.0
グラフの書き出し
手順は、
1.学習済みグラフを準備する(このtf.Graphをg_1とする)
2.g_1のvariablesをnumpy arrayに変換
3.新しいtf.Graph(これをg_2とする)を作って、Step2のnumpy arrayをtf.constantのテンソルに変換
4.g_2にg_1をコピーするけど、variablesの部分をStep3のtf.constantに置き換える
5.g_2を書き出す
です。
今回は学習が終わった後のcheckpointファイルから学習済みグラフを作って、それを書き出すという方法で実装しました。
(対象のグラフは80x80x3の画像を入力とし、最終的にsoftmaxの結果を出力するCNNです)
import tensorflow as tf
import numpy as np
import run_train import inference
g_1 = tf.Graph()
vars = {}
with g_1.as_default():
with tf.Session() as sess:
#グラフを再構築して、model.ckptから学習済みテンソルをレストアする
image_placeholder = tf.placeholder('float', shape=(None, 80, 80, 3), name='input_image')
keep_prob = tf.placeholder('float', name='keep_prob')
logits_op = inference(image_plachholder, keep_prob)
saver = tf.train.Saver()
saver.restore(sess, 'model.ckpt')
for v in tf.trainable_variables():
vars[v.value().name] = v.eval() #variableをnumpy arrayに変換
g_2 = tf.Graph()
consts = {}
with g_2.as_default():
with tf.Session() as sess:
for k in vars.keys():
consts[k] = tf.constant(vars[k])
tf.import_graph_def(g_1.as_graph_def(), input_map=consts, name="")
tf.train.write_graph(g_2.as_graph_def(), './', 'trained_graph.pb', as_text=False)
TensorFlow0.9.0から(?) tensorflow/python/framework/graph_utilが追加され、その中のconvert_variables_to_constantsを使っても同じことができます。
その場合は、
…学習済みグラフの再構築&学習済みテンソルのレストア output_graph_def = graph_util.convert_variables_to_constants(sess, g_1.as_graph_def(), ['Softmax']) tf.train.write_graph(output_graph_def, './', 'trained_graph.bp', as_text=False)
pbファイルの読み込み
pbファイルから学習済みグラフを読み込む方法です。
with tf.Graph().as_default():
with open('trained_graph.bp', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name="")
出力の取得は、テンソル名で指定できる。
logits = sess.run('softmax:0', feed_dict={'input_image:0': images, 'keep_prob:0': 1.0})
グラフ内のテンソル名は下記で確認できる。
assing_ops = tg.Graph.get_operations(sess.graph)
for op in assign_ops:
print op.name
for output_name in op.outputs:
print "output : " , output.name #出力テンソルの名前
【ハマったとこ】ExponentialMovingAverageを使っているとインポートできない?
ExponentialMovingAverageを使ったグラフをimport_graph_defすると
ValueError: graph_def is invalid at node u'ExponentialMovingAverage/AssignMovingAvg': Input tensor 'moments/moments_1/mean/ExponentialMovingAverage:0' Cannot convert a tensor of type float32 to an input of type float32_ref.
が出る。原因が分からない(~_~)
あとがき
ただ、この方法だとdropoutとかいらない処理などの不要なものが残ってしまうので、最終的にはちゃんとグラフを作り直した方がいいような気がしてる。
さて、最近はラズパイで動かそうと思って色々試している途中…なかなかすんなりとはいかないものですね。




