TensorFlow MobileのHello World [スマホでAIモデルを実行する]
TensorFlowで作成した「学習済みモデル」をAndroidで実行する方法です。
前提条件
TensorFlow.jsのHello World [WebでAIモデルを実行する] |
と同様な事を行います。環境設定などがお済でない方は先にご覧ください。
1. モデル(チェックポイント、ログ、PBファイル)の作成
次のコードをJupyter Notebookで実行します。
import tensorflow as tf with tf.name_scope('X'): x = tf.placeholder(tf.int32) with tf.name_scope('Y'): y = tf.Variable(3) with tf.name_scope('Z'): z = tf.add(x, y) saver = tf.train.Saver() init =tf.global_variables_initializer() with tf.Session() as sess: tf.summary.FileWriter("logs", sess.graph) tf.train.write_graph(sess.graph_def, './', 'graph.pbtxt') sess.run(init) result = sess.run(z, feed_dict={x:5}) saver.save(sess, 'ckpt/my_model') result
TensorBoardで確認するとグラフは次のようになります。

2. summarize_graphで入出力ノードを検査する
--in_graphのファイルパスは適宜、変更してください。
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/foo/graph.pbtxt
入力ノードは「X/Placeholder」。出力ノードは「Z/Add」となっています。

ただし、summarize_graphはあくまでも検査(予測)なのでTensorBoardでも再確認して下さい。また、graph.pbtxtファイルの中身はテキスト形式なのでそちらでも確認可能です。
3. freeze_graphでPBファイルとチェックポイントファイルを固めて「Frozen Model」にする
frozen_graph.pb(Frozen Model形式)を作成します。
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/foo/graph.pbtxt --input_checkpoint=/foo/ckpt/my_model --output_graph=/foo/frozen_graph.pb --output_node_names=Z/Add
TensorFlow.jsでは「Frozen Model」を「Web-friendly format」に変換しましたが、スマホではFrozen Modelをそのまま使用します。
4. Android Studioでプロジェクトを作成する
Android Studioで新規プロジェクトを作成します。
4-1. build.gradleの設定
build.gradle(モジュール:app)の下部に次のコードを追記します。
allprojects { repositories { jcenter() } } dependencies { api 'org.tensorflow:tensorflow-android:+' }
これだけで、モバイルでTensorFlowが使用できるようになります。
4-2. frozen_graph.pbを取り込む
OS側の操作で\app\src\mainにassetsフォルダを作成して、その中にfrozen_graph.pbを移動します。
Android Stduioでは次のように表示されます。

4-3. ソースコード
import android.app.AlertDialog; import android.support.v7.app.AppCompatActivity; import android.os.Bundle; import android.view.View; import org.tensorflow.contrib.android.TensorFlowInferenceInterface; public class MainActivity extends AppCompatActivity { @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); findViewById(R.id.button).setOnClickListener(new View.OnClickListener() { @Override public void onClick(View v) { TensorFlowInferenceInterface inferenceInterface = new TensorFlowInferenceInterface(getAssets(), "frozen_graph.pb"); int[] inputs = {5}; int[] outputs = new int[1]; String[] outputNames = {"Z/Add"}; // 入力データを設定する inferenceInterface.feed("X/Placeholder", inputs); // モデルの推論(実行) inferenceInterface.run(outputNames); // モデルから結果を取得する inferenceInterface.fetch(outputNames[0], outputs); AlertDialog.Builder alertDialogBuilder = new AlertDialog.Builder(MainActivity.this); alertDialogBuilder.setTitle("結果"); alertDialogBuilder.setMessage(String.valueOf(outputs[0])); alertDialogBuilder.setPositiveButton("OK", null); alertDialogBuilder.show(); } }); } }
[解説]
ボタンを押すとメッセージボックスに8が表示されます。
これは、元のモデルでは
となっているからです。xはプレースフォルダ(placeholder)、yは変数(Variable)で3で定義されています。
xのプレースフォルダはAndroid側で渡す入力値です。
なので、21行目の「int[] inputs = {5};」を「int[] inputs = {7};」にすると10が表示されます。
[TensorFlowInferenceInterface]
現在の所、TensorFlowInferenceInterfaceクラスについては詳細な情報がありません。Googleさんで検索しても約 2,360 件しかヒットしません。
なので、feed/run/fetchの各メソッドの宣言をまとめてみました。
// feed系 public void feed(String inputName, boolean[] src, long... dims) public void feed(String inputName, float[] src, long... dims) public void feed(String inputName, int[] src, long... dims) public void feed(String inputName, long[] src, long... dims) public void feed(String inputName, double[] src, long... dims) public void feed(String inputName, byte[] src, long... dims) public void feedString(String inputName, byte[] src) public void feedString(String inputName, byte[][] src) public void feed(String inputName, FloatBuffer src, long... dims) public void feed(String inputName, IntBuffer src, long... dims) public void feed(String inputName, LongBuffer src, long... dims) public void feed(String inputName, DoubleBuffer src, long... dims) public void feed(String inputName, ByteBuffer src, long... dims) // run系 public void run(String[] outputNames) public void run(String[] outputNames, boolean enableStats) public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) // fetch系 public void fetch(String outputName, float[] dst) public void fetch(String outputName, int[] dst) public void fetch(String outputName, long[] dst) public void fetch(String outputName, double[] dst) public void fetch(String outputName, byte[] dst) public void fetch(String outputName, FloatBuffer dst) public void fetch(String outputName, IntBuffer dst) public void fetch(String outputName, LongBuffer dst) public void fetch(String outputName, DoubleBuffer dst) public void fetch(String outputName, ByteBuffer dst)
引数さえわかれば、なんとかなりますね :-)
関連記事
この記事を書いた人
![]() | 💻 ITスキル・経験 サーバー構築からWebアプリケーション開発。IoTをはじめとする電子工作、ロボット、人工知能やスマホ/OSアプリまで分野問わず経験。 画像処理/音声処理/アニメーション、3Dゲーム、会計ソフト、PDF作成/編集、逆アセンブラ、EXE/DLLファイルの書き換えなどのアプリを公開。詳しくは自己紹介へ |
プチモンテ代表、アーティスト名:プチモンテ | |
🎵 音楽制作 BGMは楽器(音源)さえあれば、何でも制作可能。歌モノは主にロック、バラード、ポップスを制作。歌詞は叙情詩、叙情的な楽曲が多い。楽曲制作は2023年12月中旬 ~ |
オリジナル曲を始めました✨
YouTubeで各楽曲を公開しています🌈
https://www.youtube.com/@petitmonte