TensorFlow.jsのHello World [WebでAIモデルを実行する]
TensorFlow.jsはWeb上でAIのモデルを実行することが可能な画期的なライブラリです。ブラウザ標準のWebGLで「GPUを使用する」ので処理が高速化されています。
今回はTensorFlow(Python)でモデルを作成して、そのモデルをWeb-friendly formatに変換してTensorFlow.jsで動作できるようにします。
前提条件
次のインストール、ビルドを行っているものとします。
・TensorFlow.jsコンバータをインストールする |
・summarize_graph/freeze_graphのビルドを行う |
また、TensorBoardを起動できる状態にして下さい。
・TensorBoardとJupyter Notebookを同時に起動する |
・TensorBoardに表示されているログをリセット(初期化)する |
1. モデル(チェックポイント、ログ、PBファイル)の作成
次のコードをJupyter Notebookで実行します。
import tensorflow as tf with tf.name_scope('X'): x = tf.placeholder(tf.int32) with tf.name_scope('Y'): y = tf.Variable(3) with tf.name_scope('Z'): z = tf.add(x, y) saver = tf.train.Saver() init =tf.global_variables_initializer() with tf.Session() as sess: tf.summary.FileWriter("logs", sess.graph) tf.train.write_graph(sess.graph_def, './', 'graph.pbtxt') sess.run(init) result = sess.run(z, feed_dict={x:5}) saver.save(sess, 'ckpt/my_model') result
TensorBoardで確認するとグラフは次のようになります。
2. summarize_graphで入出力ノードを検査する
--in_graphのファイルパスは適宜、変更してください。
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/foo/graph.pbtxt
入力ノードは「X/Placeholder」。出力ノードは「Z/Add」となっています。
ただし、summarize_graphはあくまでも検査(予測)なのでTensorBoardでも再確認して下さい。また、graph.pbtxtファイルの中身はテキスト形式なのでそちらでも確認可能です。
3. freeze_graphでPBファイルとチェックポイントファイルを固めて「Frozen Model」にする
frozen_graph.pb(Frozen Model形式)を作成します。
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/foo/graph.pbtxt --input_checkpoint=/foo/ckpt/my_model --output_graph=/foo/frozen_graph.pb --output_node_names=Z/Add
4. Frozen ModelをWeb-friendly formatに変換する
カレントディレクトリを/fooに移動する。
cd /foo
TensorFlow.jsコンバータで変換する。
tensorflowjs_converter --input_format=tf_frozen_model --output_node_names='Z/Add' --saved_model_tags=serve frozen_graph.pb ./web_model
次のようなファイルが生成されます。
この3つのファイルをダウンロードします。そして、modelフォルダを作成してその中にファイルを移動します。
5. JavaScriptでモデルを実行する
[test.html]
<!DOCTYPE html> <html lang="ja"> <head> <meta charset="utf-8"> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.12.5"></script> <!-- <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script> --> </head> <body> <script> tf.loadFrozenModel("model/tensorflowjs_model.pb","model/weights_manifest.json") .then(function(model){ // tf.tidyはメモリリークを回避する為に使用する定型メソッドです。 z = tf.tidy(function(){ // モデルの推論 return model.execute({"X/Placeholder": tf.scalar(5, 'int32')}).dataSync(); }) alert(z) }) .catch(function(err){ document.write(err); }); </script> </body> </html>
[ブラウザ]
Chromeではローカルで実行するとセキュリティ的にエラーになります。Chromeを使用する場合はWebサーバーにアップロードしてから実行します。
面倒な方はFireFoxまたはEdgeならばローカルで実行可能です。
※残念な事にIE11はローカル、Webサーバー共に実行できないようです。
[解説]
メッセージボックスに8が表示されたと思います。これは、元のモデルでは
となっているからです。xはプレースフォルダ(placeholder)、yは変数(Variable)で3で定義されています。
xのプレースフォルダはJavaScript側の「モデルの推論」で渡す入力値です。
なので、17行目の「tf.scalar(5, 'int32')」を「tf.scalar(7, 'int32')」にすると10が表示されます。
注意事項
2018年7月31日のTensorFlow.jsの最新版は「0.12.0」です。
恐らくこの時点でのモデルの読み込みはtf.loadModel()を使用するのが一般的です。ただし、このtf.loadModel()では「Frozen Model」形式は読み込めないようです。
Frozen Modelを指定すると、次のようなエラーがでます。
Error: Missing field "modelTopology" from model JSON at path./model/weights_manifest.json |
Error: The JSON from HTTP path ./model/weights_manifest.json contains neither model topology or manifest for weights. |
なので、今回はtf.loadFrozenModel()という非公開メソッドを使用しています。将来的には正式のAPIとして公開される予定のようです。
WebでAIのテスト
参考サイト
Preparing models for mobile deployment (公式)
モバイル配備のためのモデルを準備する (日本語訳)