zukucode
主にWEB関連の情報を技術メモとして発信しています。

TensorFlow.jsでMNISTのアプリを作成する(初心者向け)

TensorFlow.jsVue.jsで、手書きの数字を予測するMNISTのWEBアプリを作成して、TensorFlow.jsの基本的な動作を学習します。

完成したイメージです。

mnistアプリのスクリーンショット

TensorFlow.jsでMNIST学習済モデルを読み込みブラウザで手書き文字認識をするを参考にさせていただきました。

モデル作成の部分は上記のページの通りに作成できましたが、予測の処理については、TensorFlow.jsのバージョンの違いでAPI仕様が大幅に変更されていたため、独自実装となりました。

本記事でのバージョンは以下になっています。

@tensorflow/tfjs
1.2.8
vue
2.6.10

モデルの入手

参考ページのほぼ流用ですが、Google Colaboratoryで以下の手順でコマンドを実行して、学習済みモデルをダウンロードします。

以下のコマンドで、kerasのサンプルプログラムを取得してします。

「学習したモデルをmnist_cnn_model.h5というファイル名で保存する処理」をソースの最後に追加しています。

!wget https://raw.githubusercontent.com/keras-team/keras/master/examples/mnist_cnn.py
!echo "model.save('mnist_cnn_model.h5')" >> mnist_cnn.py

サンプルプログラムのソースは以下で確認できます。

https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py

取得した学習スクリプトを実行します。

学習の過程が確認でき、最終的には0.9907の精度のモデルが作成されたことが確認できます。

!python3 mnist_cnn.py
60000/60000 [==============================] - 15s 250us/step - loss: 0.2535 - acc: 0.9229 - val_loss: 0.0565 - val_acc: 0.9795
Epoch 2/12
60000/60000 [==============================] - 8s 139us/step - loss: 0.0855 - acc: 0.9744 - val_loss: 0.0409 - val_acc: 0.9854
Epoch 3/12
60000/60000 [==============================] - 8s 139us/step - loss: 0.0655 - acc: 0.9808 - val_loss: 0.0305 - val_acc: 0.9893
Epoch 4/12
60000/60000 [==============================] - 8s 139us/step - loss: 0.0527 - acc: 0.9842 - val_loss: 0.0304 - val_acc: 0.9897
Epoch 5/12
60000/60000 [==============================] - 8s 139us/step - loss: 0.0470 - acc: 0.9862 - val_loss: 0.0292 - val_acc: 0.9907
Epoch 6/12
60000/60000 [==============================] - 8s 140us/step - loss: 0.0414 - acc: 0.9875 - val_loss: 0.0285 - val_acc: 0.9901
Epoch 7/12
60000/60000 [==============================] - 8s 140us/step - loss: 0.0366 - acc: 0.9890 - val_loss: 0.0281 - val_acc: 0.9901
Epoch 8/12
60000/60000 [==============================] - 8s 139us/step - loss: 0.0344 - acc: 0.9894 - val_loss: 0.0277 - val_acc: 0.9909
Epoch 9/12
60000/60000 [==============================] - 9s 142us/step - loss: 0.0313 - acc: 0.9907 - val_loss: 0.0254 - val_acc: 0.9909
Epoch 10/12
60000/60000 [==============================] - 8s 140us/step - loss: 0.0289 - acc: 0.9912 - val_loss: 0.0253 - val_acc: 0.9911
Epoch 11/12
60000/60000 [==============================] - 8s 138us/step - loss: 0.0275 - acc: 0.9913 - val_loss: 0.0273 - val_acc: 0.9917
Epoch 12/12
60000/60000 [==============================] - 8s 139us/step - loss: 0.0266 - acc: 0.9915 - val_loss: 0.0280 - val_acc: 0.9907
Test loss: 0.028025911626653396
Test accuracy: 0.9907

tensorflowjs_converterTensorFlow.jsで読み込める形式に変換するためにtensorflowjsをインストールします。

(インストール時にエラーが発生してしまいましたが、無視してそのまま次のステップを実行しても問題ありませんでした)

!pip3 install tensorflowjs

tensorflowjsをインストール後、以下のコマンドで、TensorFlow.jsで読み込める形式に変換します。

!tensorflowjs_converter --input_format keras mnist_cnn_model.h5 model

変換したモデルをzipにして、自分のPCにダウンロードします。

!zip -r model.zip model && ls -l
import google.colab
google.colab.files.download('model.zip')

以上でモデルの準備は完了です。

キャンバスの作成

次にWEBアプリを作成します。

まずは手書きができるキャンバスを作成します。

細かな制御を省いた最低限のものは以下のようにします。

判定に使用する画像は黒背景に白い数字である必要があるため、canvasの背景色を黒にして白で手書きができるようにしています。

Canvas.vue
<template>
  <div>
    <div>
      <canvas
        ref="canvas"
        class="canvas"
        :width="size.width"
        :height="size.height"
        @mousedown="handleMouseDown"
        @mouseup="handleMouseUp"
        @mousemove="handleMouseMove"
      ></canvas>
    </div>
    <a class="button" @click="clear">clear</a>
  </div>
</template>

<style scoped>
.canvas {
  background-color: #000; /*黒背景*/
}
</style>

<script>
export default {
  data() {
    return {
      size: {
        width: 400,
        height: 400,
      },
      mouse: {
        x: 0,
        y: 0,
        down: false,
      },
    };
  },
  computed: {
    currentMouse() {
      const c = this.$refs.canvas;
      const rect = c.getBoundingClientRect();

      return {
        x: this.mouse.x - rect.left,
        y: this.mouse.y - rect.top,
      };
    },
  },
  mounted() {
    this.clear();
  },
  methods: {
    draw() {
      if (this.mouse.down) {
        const ctx = this.$refs.canvas.getContext('2d');
        ctx.lineTo(this.currentMouse.x, this.currentMouse.y);
        ctx.strokeStyle = '#fff'; // 白文字
        ctx.lineWidth = 20;
        ctx.stroke();
      }
    },
    handleMouseDown(event) {
      this.mouse = {
        x: event.pageX,
        y: event.pageY,
        down: true,
      };

      const ctx = this.$refs.canvas.getContext('2d');
      ctx.moveTo(this.currentMouse.x, this.currentMouse.y);
    },
    handleMouseUp() {
      this.mouse.down = false;
    },
    handleMouseMove(event) {
      Object.assign(this.mouse, {
        x: event.pageX,
        y: event.pageY,
      });

      this.draw();
    },
    clear() {
      const ctx = this.$refs.canvas.getContext('2d');
      ctx.clearRect(0, 0, this.size.width, this.size.height);
      ctx.beginPath();
    },
  },
};
</script>

モデルの読込

前ステップでダウンロードしたモデルを、公開フォルダなどに配置して、以下のように読み込みます。

const model = await tf.loadLayersModel('model/model.json'); // モデルのパスを指定して読み込む

モデルの判定

今回作成したモデルで推論を行うには[batchSize, height, width, colorChannels]の形式で入力(引数)に指定する必要があります。

今回は手書きの画像1枚が対象なので、batchSize1となります、

widthheightは、学習で使用した画像(28×28)と同じサイズにするため、両方28となります。

colorChannelsも、学習で使用した画像は白黒画像のため、1となります。

以下のように入力形式に変換します。

import * as tf from '@tensorflow/tfjs';

const input = tf.browser
    .fromPixels(this.$refs.canvas, 1)
    .toFloat()
    .resizeNearestNeighbor([28, 28])
    .div(tf.scalar(255))
    .expandDims();

browser.fromPixel

キャンバスのデータをtf.Tensorオブジェクトに変換します。

browser.fromPixels(データ, チャンネル数)の形式で指定します。

データにはキャンバスのhtml要素をそのまま指定できます。

試しにこの段階での入力データを確認してみると、400, 400は画像サイズで1はチャンネル数となっています。

const input = tf.browser.fromPixels(this.$refs.canvas, 1)
console.log(input.shape);
// [400, 400, 1]

toFloat

配列のタイプをfloat(float32)に変換します。

resizeNearestNeighbor

画像サイズを変換します。

resizeNearestNeighbor([height, width])の形式で変換します。

入力形式は28×28なので、resizeNearestNeighbor([28, 28])となります。

試しにこの段階での入力データを確認してみると、画像サイズが28×28になっていることが確認できます。

const input = tf.browser.fromPixels(this.$refs.canvas, 1).toFloat().resizeNearestNeighbor([28, 28])
console.log(input.shape);
// [28, 28, 1]

div

現段階でのデータは0〜255で表現されています。これを0〜1にして正則化する必要があるので、255で割ります。

div(割る値)の形式で指定します。

割る値もtensorflowで扱う型で指定します。tf.scalar(数値)とすると、tensorflowで扱える形式で数値を宣言できます。

試しに正則化の前後でデータの値を比較してみます。dataSync()でデータを表示できます。

784(28×28)のfloat配列になっていることが確認できます。

const input = tf.browser.fromPixels(this.$refs.canvas, 1).toFloat().resizeNearestNeighbor([28, 28])
console.log(input.dataSync());
// Float32Array(784) [0, 255, 0,  …]

const input2 = tf.browser.fromPixels(this.$refs.canvas, 1).toFloat().resizeNearestNeighbor([28, 28]).div(tf.scalar(255))
console.log(input2.dataSync());
// Float32Array(784) [0, 1, 0,  …]

expandDims

最後に、データの先頭に1(batchSize)を追加します。

expandDims(追加する場所)の形式で指定します。引数省略時は0(先頭)に追加されます。

const input = tf.browser
    .fromPixels(this.$refs.canvas, 1)
    .toFloat()
    .resizeNearestNeighbor([28, 28])
    .div(tf.scalar(255))
    .expandDims();
console.log(input.shape);
// [1, 28, 28, 1]

これでモデルのに対応した入力形式に変換できましたので、推論をおこないます。

predictで推論を行い、dataSyncで推論結果を取得しています。

推論結果は要素数10の1次元配列になっています。

以下の例ではscore[1]の値が0.99...となっているため、99%の確率で「1」であると推論しています。

const score = model.predict(input).dataSync();

console.log(score);
// Float32Array(10) [0.0002011491742450744, 0.9975615739822388, 0.000036233086575521156, 0.000020055860659340397, 0.00005023562698625028, 0.0004250735801178962, 0.0010217257076874375, 0.00016747607151046395, 0.00039549331995658576, 0.00012142454943386838]

メモリリークの回避

不要なtf.Tensorオブジェクトは明示的に破棄していかないとメモリリークをしてしまうようです。

tf.Tensorオブジェクトを扱う処理はすべてtf.tidyで行うようにすれば自動的に破棄してくれるため、最終的には以下のような実装となります。

const score = tf.tidy(() => {
    const input = tf.browser
    .fromPixels(payload.imageData, 1)
    .toFloat()
    .resizeNearestNeighbor([28, 28])
    .div(tf.scalar(255))
    .expandDims();
    return model.predict(input).dataSync();
});

関連記事

  • webpack lessをImportしてビルドする

    Vue.js lessを使いwebpackでビルドするでvueファイルの中でlessを実装してビルドする方法を紹介しました。htmlやbodyなどに適用するベースのクラスを外部のlessファイルに実装...


  • ExtractTextPluginでcssファイルを出力する

    webpackのプラグインExtractTextPluginを使って、ビルドされたjsファイルからstyleの部分を抽出してcssファイルで出力します。extract-text-webpack-plu...


  • webpack-dev-serverで開発サーバーを起動する

    webpack-dev-serverで開発サーバーを起動します。ファイルやフォルダ構成などの環境はvueファイル(単一ファイルコンポーネント)をwebpackでビルドで紹介したものと同様とします。we...


  • webpack モジュールのパスを絶対パスで指定する方法

    自分で作成したモジュールをインポートするときはインポートするファイルを基準に相対パスで指定する必要があります。フォルダ構成によっては深く階層を辿らないといけないので、フォルダ構成が変わってしまうと大変...


  • Vuex 厳格(strict)モードでエラーになるよくある原因

    Vuexの厳格(strict)モードでエラーになってしまう原因でよくあるパターンです。stateに格納した配列をソートして表示するときに、stateの値をcomputedやgettersでそのままソー...


  • Vuex stateを作成して各コンポーネントで参照する

    Vue.jsにVuexを導入するでVuexの導入が完了しました。実際にstoreを作成して、各コンポーネントでその値を参照できるようにします。最初にstoreを作成します。mutations.jsにs...


  • Vuex mutationでstateの値を変更する

    Vuex stateを作成して各コンポーネントで参照するでstateを作成して各コンポーネントで参照できるようになりました。stateはmutationの処理でしか変更できません。コンポーネントからm...


  • Vuex moduleに定義したstateをコンポーネントで簡単に取得する

    Vuexのモジュールに定義したstateを各コンポーネントで参照する方法を紹介します。以下のようにmoduleが定義されているとします。各コンポーネントではthis.$store.state.モジュー...


  • Vuexとaxiosを連携する

    Vuex actionでローディングアイコンを表示するで、actionの処理でローディングの表示非表示を制御するmutationを実行するようにしました。ajaxの処理が増えるたびに同じmutatio...


  • Vuex モジュールのactionの処理で別モジュールのactionをコールする

    Vuexのモジュール内のactionの処理で別モジュールのactionをコールする方法を紹介します。ルートのactionの場合、moduleで定義したactionをコールする場合は、以下のようにモジュ...