TOP > カテゴリ > Python・人工知能・機械学習 >

TensorFlow MobileのHello World [スマホでAIモデルを実行する]

Python/TensorFlowの使い方(目次)

TensorFlowで作成した「学習済みモデル」をAndroidで実行する方法です。

前提条件

TensorFlow.jsのHello World [WebでAIモデルを実行する]

と同様な事を行います。環境設定などがお済でない方は先にご覧ください。

1. モデル(チェックポイント、ログ、PBファイル)の作成

次のコードをJupyter Notebookで実行します。

import tensorflow as tf

with tf.name_scope('X'):
    x = tf.placeholder(tf.int32)

with tf.name_scope('Y'):
    y = tf.Variable(3)
    
with tf.name_scope('Z'):    
    z = tf.add(x, y)

saver = tf.train.Saver()

init =tf.global_variables_initializer()
with tf.Session() as sess:
    tf.summary.FileWriter("logs", sess.graph)  
    tf.train.write_graph(sess.graph_def, './', 'graph.pbtxt')
    
    sess.run(init)
    result = sess.run(z, feed_dict={x:5})   
    
    saver.save(sess, 'ckpt/my_model')            
result    

TensorBoardで確認するとグラフは次のようになります。

2. summarize_graphで入出力ノードを検査する

--in_graphのファイルパスは適宜、変更してください。

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/foo/graph.pbtxt

入力ノードは「X/Placeholder」。出力ノードは「Z/Add」となっています。

ただし、summarize_graphはあくまでも検査(予測)なのでTensorBoardでも再確認して下さい。また、graph.pbtxtファイルの中身はテキスト形式なのでそちらでも確認可能です。

3. freeze_graphでPBファイルとチェックポイントファイルを固めて「Frozen Model」にする

frozen_graph.pb(Frozen Model形式)を作成します。

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/foo/graph.pbtxt --input_checkpoint=/foo/ckpt/my_model --output_graph=/foo/frozen_graph.pb --output_node_names=Z/Add 

TensorFlow.jsでは「Frozen Model」を「Web-friendly format」に変換しましたが、スマホではFrozen Modelをそのまま使用します。

4. Android Studioでプロジェクトを作成する

Android Studioで新規プロジェクトを作成します。

4-1. build.gradleの設定

build.gradle(モジュール:app)の下部に次のコードを追記します。

allprojects {
    repositories {
        jcenter()
    }
}

dependencies {
    api 'org.tensorflow:tensorflow-android:+'
}

これだけで、モバイルでTensorFlowが使用できるようになります。

4-2. frozen_graph.pbを取り込む

OS側の操作で\app\src\mainにassetsフォルダを作成して、その中にfrozen_graph.pbを移動します。

Android Stduioでは次のように表示されます。

4-3. ソースコード

import android.app.AlertDialog;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.view.View;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public class MainActivity extends AppCompatActivity {

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        findViewById(R.id.button).setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                TensorFlowInferenceInterface inferenceInterface =
                        new TensorFlowInferenceInterface(getAssets(),
                                "frozen_graph.pb");

                int[] inputs = {5};
                int[] outputs = new int[1];
                String[] outputNames = {"Z/Add"};

                // 入力データを設定する
                inferenceInterface.feed("X/Placeholder", inputs);

                // モデルの推論(実行)
                inferenceInterface.run(outputNames);

                // モデルから結果を取得する
                inferenceInterface.fetch(outputNames[0], outputs);

                AlertDialog.Builder alertDialogBuilder = new AlertDialog.Builder(MainActivity.this);
                alertDialogBuilder.setTitle("結果");
                alertDialogBuilder.setMessage(String.valueOf(outputs[0]));
                alertDialogBuilder.setPositiveButton("OK", null);
                alertDialogBuilder.show();
            }
        });
    }
}

[解説]

ボタンを押すとメッセージボックスに8が表示されます。

これは、元のモデルでは

z = x + y

となっているからです。xはプレースフォルダ(placeholder)、yは変数(Variable)で3で定義されています。

z = x + 3

xのプレースフォルダはAndroid側で渡す入力値です。

なので、21行目の「int[] inputs = {5};」を「int[] inputs = {7};」にすると10が表示されます。

[TensorFlowInferenceInterface]

現在の所、TensorFlowInferenceInterfaceクラスについては詳細な情報がありません。Googleさんで検索しても約 2,360 件しかヒットしません。

なので、feed/run/fetchの各メソッドの宣言をまとめてみました。

// feed系
public void feed(String inputName, boolean[] src, long... dims)
public void feed(String inputName, float[] src, long... dims)
public void feed(String inputName, int[] src, long... dims)
public void feed(String inputName, long[] src, long... dims)
public void feed(String inputName, double[] src, long... dims)
public void feed(String inputName, byte[] src, long... dims)
public void feedString(String inputName, byte[] src)
public void feedString(String inputName, byte[][] src)
public void feed(String inputName, FloatBuffer src, long... dims)
public void feed(String inputName, IntBuffer src, long... dims)
public void feed(String inputName, LongBuffer src, long... dims)
public void feed(String inputName, DoubleBuffer src, long... dims)
public void feed(String inputName, ByteBuffer src, long... dims)

// run系
public void run(String[] outputNames)
public void run(String[] outputNames, boolean enableStats)
public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames)

// fetch系
public void fetch(String outputName, float[] dst) 
public void fetch(String outputName, int[] dst)
public void fetch(String outputName, long[] dst)
public void fetch(String outputName, double[] dst)
public void fetch(String outputName, byte[] dst)
public void fetch(String outputName, FloatBuffer dst) 
public void fetch(String outputName, IntBuffer dst)
public void fetch(String outputName, LongBuffer dst)
public void fetch(String outputName, DoubleBuffer dst)
public void fetch(String outputName, ByteBuffer dst)

引数さえわかれば、なんとかなりますね :-)





関連記事



公開日:2018年08月21日 最終更新日:2018年08月24日
記事NO:02717