TOP > カテゴリ > Python・人工知能・機械学習 >

Tensorflow detection model zooの「学習済みモデル」をTensorFlow.jsで動かす

Python/TensorFlowの使い方(目次)

Tensorflow detection model zooにある 「ssd_mobilenet_v1_coco」を転移学習で「顔検出モデル」にした学習済みモデルをTensorFlow.jsで動かしてみます。

※顔検出モデルは顔検出、顔識別(顔認識)に挑戦してみるの2章で作成したモデルです。

Web-friendly formatに変換

はじめに「顔検出モデル」のFrozen Model(frozen_inference_graph.pb)をTensorFlow.js用にWeb-friendly formatに変換します。

tensorflowjs_converter  --input_format=tf_frozen_model --output_node_names='detection_boxes,detection_scores,detection_classes,num_detections'  --saved_model_tags=serve frozen_inference_graph.pb  ./web_model

※関連記事:TensorFlow.jsのHello World

後はTensorFlow.jsで推論するだけです。

ソースコード

画像を読み込むと、自動的に顔を検出して矩形で囲みます。

<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.12.5"></script>
<script>

// キャンバス
var src_canvas; 
var src_ctx;

// イメージ
var image;
   
///// イベント
    
window.onload = function (){
  
  // IE判定
  var userAgent = window.navigator.userAgent.toLowerCase();
  if(userAgent.indexOf('msie') != -1 || userAgent.indexOf('trident') != -1) {
    alert('Internet Explorerでは動作しません。');
  }
      
  src_canvas = document.getElementById("SrcCanvas");
  src_ctx = src_canvas.getContext("2d");    
  
  image = document.getElementById("img_source");
}

// ドラッグオーバー
function onDragOver(event){ 
  event.preventDefault(); 
} 
  
// ドロップ    
function onDrop(event){
  onAddFile(event);
  event.preventDefault(); 
}  

// ユーザーによりファイルが追加された  
function onAddFile(event) {
  var files;
  var errflg = false;
  var reader = new FileReader();
  
  if(event.target.files){
    files = event.target.files;
  }else{ 
    files = event.dataTransfer.files;   
  }    

  // ファイルが読み込まれた
  reader.onload = function (event) {
    
    // イメージが読み込まれた
    image.onload = function (){
      
      console.log(image.width + ' x ' + image.height);    
           
      // モデルの読み込み
      tf.loadFrozenModel("model/tensorflowjs_model.pb","model/weights_manifest.json")    
        .then(function(model){         
           
           // tf.tidyはメモリリークを回避する為に使用する定型メソッドです。
           z = tf.tidy(function(){

               src_canvas.width  = image.width;
               src_canvas.height = image.height;
               src_ctx.drawImage(image,0,0);           
               var imagedata = src_ctx.getImageData(0, 0,image.width, image.height);
               
               // 画像データを取得する 
               var index =0;
               var dRow,dCol;
               var raw = new Array(image.width *  image.height *3);
                 
               for (var y = 0; y < image.height; y++) {
                 dRow= (y * image.width * 3);
                 for (var x = 0;x < image.width; x++) {
                   dCol = dRow + (x * 4);
                   raw[index++] = imagedata.data[dCol];
                   raw[index++] = imagedata.data[dCol+1];
                   raw[index++] = imagedata.data[dCol+2];
                 }
               }
               
               // 配列から1階テンソルの生成 
               x = tf.tensor1d(raw);

               // モデルに「制御フロー」または「動的シェイプ操作」が含まれている為、非同期で推論を行う
               model.executeAsync({"image_tensor": x.reshape([1, image.height, image.width , 3])})
                .then(function(result){  
               
                   console.log(result); 

                   detection_boxes = result[0].reshape([result[0].size]).dataSync();
                   detection_scores = result[1].reshape([result[1].size]).dataSync();
                   detection_classes = result[2].reshape([result[2].size]).dataSync();
                   num_detections = result[3].reshape([result[3].size]).dataSync();
                   
                   // テスト用
                   console.log(detection_boxes);
                   console.log(detection_scores);
                   console.log(detection_classes);
                   console.log(num_detections);
                   
                   console.log(result[0].print());
                   console.log(result[1].print());                          

                   // データはtop,left,bottom,rightの順番なので注意
                   var top    = detection_boxes[0] * image.height;
                   var left   = detection_boxes[1] * image.width;
                   var bottom = detection_boxes[2] * image.height;
                   var right  = detection_boxes[3] * image.width;
                  
                   // 矩形の描画                  
                   src_ctx.strokeStyle = "rgb(255, 255, 255)";   
                   src_ctx.beginPath();
                   src_ctx.strokeRect(left, top,right-left ,bottom-top);
              });
              
              
           });                    

      }).catch(function(err){
         document.write(err);    
      });    

    };      
       
    // イメージが読み込めない
    image.onerror  = function (){
      alert('このファイルは読み込めません。');  
    };

    image.src = reader.result;       
  };
  
  if (files[0]){    
    run_flg = true;    
    reader.readAsDataURL(files[0]); 
    filename = files[0].name;
    document.getElementById("inputfile").value = '';
  }
}      

</script> 
</head>
<body ondrop="onDrop(event);" ondragover="onDragOver(event);"> 
<input type="file" id="inputfile" accept="image/jpeg,image/png,image/gif,image/bmp,image/x-icon" onchange="onAddFile(event);">
<canvas id="SrcCanvas" style="margin:0;paddling:0;"></canvas>
<img id="img_source" style="display:;">
</body>
</html> 

さて、本来ならば画像付きでご紹介するのですが、今回はありません。

なぜならば、Tensorflow.jsがまだ対応していないのです。

現在のTensorflow.js(0.12.5)からの戻り値は「常」に次のようになります。

// detection_boxes
Tensor
    [[[0.2497606, 0.119187 , 0.7361705, 0.8404452],
      [0.5242211, 0.0935305, 0.8911508, 0.8584976],
      [0.8173539, 0.3000037, 1        , 0.9805864],
      ...,
      [0.0671041, 0.6525283, 0.3557873, 0.7662106],
      [0.2164454, 0.7507971, 0.8537828, 0.9608495],
      [0.3435455, 0.0791326, 0.5383347, 0.8567387]]] 

// detection_scores
Tensor
     [[0.001851, 0.0015872, 0.0015419, ..., 0.0001173, 0.0001145, 0.0001144],]

どんな画像を読み込んでも、常にこの値です。

[summarize_graphの結果]

inputs/outputsは間違えていません。

Found 1 possible inputs: (name=image_tensor, type=uint8(4), shape=[?,?,?,3]) 
Found 4 possible outputs: (name=detection_boxes, op=Identity) (name=detection_scores, op=Identity) (name=detection_classes, op=Identity) (name=num_detections, op=Identity) 

[ソースコードに関して]

基本的にはTensorFlow Mobileのソースコードと同様にしているので、本来はこれで動作するハズなんです。

Tensorflow.jsは更新が早いので、1か月も待てば対応されると思います。
※現在は8月25日





関連記事



公開日:2018年08月25日
記事NO:02720