ホーム > カテゴリ > Python・人工知能・Django >

TensorFlow.jsのHello World [WebでAIモデルを実行する]

Python/TensorFlowの使い方(目次)

TensorFlow.jsはWeb上でAIのモデルを実行することが可能な画期的なライブラリです。ブラウザ標準のWebGLで「GPUを使用する」ので処理が高速化されています。

今回はTensorFlow(Python)でモデルを作成して、そのモデルをWeb-friendly formatに変換してTensorFlow.jsで動作できるようにします。

前提条件

次のインストール、ビルドを行っているものとします。

TensorFlow.jsコンバータをインストールする
summarize_graph/freeze_graphのビルドを行う

また、TensorBoardを起動できる状態にして下さい。

TensorBoardとJupyter Notebookを同時に起動する
TensorBoardに表示されているログをリセット(初期化)する

1. モデル(チェックポイント、ログ、PBファイル)の作成

次のコードをJupyter Notebookで実行します。

import tensorflow as tf

with tf.name_scope('X'):
    x = tf.placeholder(tf.int32)

with tf.name_scope('Y'):
    y = tf.Variable(3)
    
with tf.name_scope('Z'):    
    z = tf.add(x, y)

saver = tf.train.Saver()

init =tf.global_variables_initializer()
with tf.Session() as sess:
    tf.summary.FileWriter("logs", sess.graph)  
    tf.train.write_graph(sess.graph_def, './', 'graph.pbtxt')
    
    sess.run(init)
    result = sess.run(z, feed_dict={x:5})   
    
    saver.save(sess, 'ckpt/my_model')            
result    

TensorBoardで確認するとグラフは次のようになります。

2. summarize_graphで入出力ノードを検査する

--in_graphのファイルパスは適宜、変更してください。

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/foo/graph.pbtxt

入力ノードは「X/Placeholder」。出力ノードは「Z/Add」となっています。

ただし、summarize_graphはあくまでも検査(予測)なのでTensorBoardでも再確認して下さい。また、graph.pbtxtファイルの中身はテキスト形式なのでそちらでも確認可能です。

3. freeze_graphでPBファイルとチェックポイントファイルを固めて「Frozen Model」にする

frozen_graph.pb(Frozen Model形式)を作成します。

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/foo/graph.pbtxt --input_checkpoint=/foo/ckpt/my_model --output_graph=/foo/frozen_graph.pb --output_node_names=Z/Add 

4. Frozen ModelをWeb-friendly formatに変換する

カレントディレクトリを/fooに移動する。

TensorFlow.jsコンバータで変換する。

tensorflowjs_converter  --input_format=tf_frozen_model --output_node_names='Z/Add'  --saved_model_tags=serve frozen_graph.pb  ./web_model

次のようなファイルが生成されます。

この3つのファイルをダウンロードします。そして、modelフォルダを作成してその中にファイルを移動します。

5. JavaScriptでモデルを実行する

[test.html]

<!DOCTYPE html>
<html lang="ja">
<head>
<meta charset="utf-8">
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.12.5"></script>
<!-- <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script> -->
</head>
<body>
<script>
 
  tf.loadFrozenModel("model/tensorflowjs_model.pb","model/weights_manifest.json")
    .then(function(model){
     
     // tf.tidyはメモリリークを回避する為に使用する定型メソッドです。
     z = tf.tidy(function(){
       // モデルの推論
       return model.execute({"X/Placeholder": tf.scalar(5, 'int32')}).dataSync();     
     })
     
     alert(z)
    })    
    .catch(function(err){
       document.write(err);    
    });     
</script>
</body>
</html>

[ブラウザ]

Chromeではローカルで実行するとセキュリティ的にエラーになります。Chromeを使用する場合はWebサーバーにアップロードしてから実行します。

面倒な方はFireFoxまたはEdgeならばローカルで実行可能です。

※残念な事にIE11はローカル、Webサーバー共に実行できないようです。

[解説]

メッセージボックスに8が表示されたと思います。これは、元のモデルでは

z = x + y

となっているからです。xはプレースフォルダ(placeholder)、yは変数(Variable)で3で定義されています。

z = x + 3

xのプレースフォルダはJavaScript側の「モデルの推論」で渡す入力値です。

なので、17行目の「tf.scalar(5, 'int32')」を「tf.scalar(7, 'int32')」にすると10が表示されます。

注意事項

2018年7月31日のTensorFlow.jsの最新版は「0.12.0」です。

恐らくこの時点でのモデルの読み込みはtf.loadModel()を使用するのが一般的です。ただし、このtf.loadModel()では「Frozen Model」形式は読み込めないようです。

Frozen Modelを指定すると、次のようなエラーがでます。

Error: Missing field "modelTopology" from model JSON at path./model/weights_manifest.json
Error: The JSON from HTTP path ./model/weights_manifest.json contains neither model topology or manifest for weights.

なので、今回はtf.loadFrozenModel()という非公開メソッドを使用しています。将来的には正式のAPIとして公開される予定のようです。

2018/8/21追記 TensorFlow.js(0.12.5)でtf.loadFrozenModel()が正式に対応されました。ですので0.12.5以降のバージョンを使用して下さい。

WebでAIのテスト

手書き文字認識(数字) - MNIST

参考サイト

Preparing models for mobile deployment (公式)
モバイル配備のためのモデルを準備する (日本語訳)





関連記事



公開日:2018年07月31日 最終更新日:2018年08月24日
記事NO:02709