XLAコンパイラAPI

コレクションでコンテンツを整理 必要に応じて、コンテンツの保存と分類を行います。

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

TensorFlowとXLAライブラリをインポートします。XLAには、一部または全てのモデルを XLA でコンパイルする実験的なAPIである xla.compile() が含まれています。

import tensorflow as tf

from tensorflow.contrib.compiler import xla

必要ないくつかの定数を定義し、 MNISTのデータセットを用意します。

# それぞれの入力イメージの大きさは、 28 x 28ピクセル
IMAGE_SIZE = 28 * 28
# 個別の数字のラベル [0..9] の個数
NUM_CLASSES = 10
# それぞれのトレーニングバッチ(ステップ)での標本数
TRAIN_BATCH_SIZE = 100
# トレーニングステップを実行する回数
TRAIN_STEPS = 1000
# MNISTデータセットをロードする。
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()
test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)

iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)
images, labels = iterator.get_next()
images = tf.reshape(images, [-1, IMAGE_SIZE])
images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
WARNING:tensorflow:From <ipython-input-4-4b4b30f2fbb2>:5: DatasetV1.output_types (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_types(dataset)`.
WARNING:tensorflow:From <ipython-input-4-4b4b30f2fbb2>:5: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_shapes(dataset)`.

モデルを構築する関数の定義

以下のコードブロックは、順伝搬と逆伝搬の両方を行う、1つのdenseレイヤーを持つ簡単なモデルを構築する関数を含みます。

コードが呼ばれたとき、2つの値を返します。 y は、それぞれのターゲットのクラスの予測確率を表す tf.Tensor です。 train_stepglobal_step の値を増加し、変数の更新を行う tf.Operation です。

def build_mnist_model(x, y_):
  y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)

  cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  return y, train_step

XLA の有効化

XLA を有効化するには build_mnist_model 関数を xla.compile に渡します。以下のコードブロックは、モデルを xla.compile() 関数でラップします。これにより、提供された入力を持つターゲット関数をXLAで実行できます。

[y] = xla.compile(build_mnist_model, inputs=[images, labels])
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/losses/losses_impl.py:121: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

グラフをコンパイルするとき、XLAはターゲット関数によって構築されたグラフの全てのノードを、いくつかのXLAのオペレータで置き換えます。

xla.compileは、生成されたXLAのオペレータから独立して実行できる tf.Operation を返しません 代わりに、ターゲット関数から返された tf.Operation ノードは、返された全ての tf.Tensor の値との制御依存関係として追加されます。これにより、 返されたテンソルが評価されるときに、 tf.Operation ノードの実行をトリガします。

擬似コードによるxla.compileの実装は、以下のようになります:


# TensorFlowに、XLAが扱いやすい方法でコードを実行するよう依頼する

y, train_step = build_mnist_model(images, labels)
with tf.control_dependencies([train_step]):
  y = tf.identity(y)

# TensorFlowに、XLAが扱いやすい方法でコードの実行を停止するよう依頼する

xla.compile()は常に tf.Tensor のリスト(1要素しか無かったとしても)を返します。

もしあなたが構築したグラフを今表示したら、通常のTensorFlowのグラフとそれほど変わらないことがわかり、前に述べたXLAのオペレータを見つけることができないでしょう。これは、あなたが sess.run() でグラフを実行しようとしても、実際のコンパイルは後ほど発生するからです。後ほど、TensorFlowは実際にXLAオペレータを生成する一連のグラフ書き換えパスをトリガーします。これは、すべての入力がそろったときに、計算をコンパイルして実行します。

モデルの学習とテスト

# セッションを作成しすべての変数を初期化。
# xla.compile()は、Keras model.fit() APIやTF eager modeとはまだ動作しません。
sess = tf.Session()
sess.run(tf.global_variables_initializer())

以下のコードブロックはモデルを学習します。 y の評価は、制御依存関係がある train_step をトリガします。これは、モデル変数を更新します。

# 学習用データセットを与える
sess.run(iterator.make_initializer(train_ds))

# TRAIN_STEPS ステップだけ実行する
for i in range(TRAIN_STEPS):
  sess.run(y)

print("Model trained for %s steps." % TRAIN_STEPS)
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:348: Iterator.output_types (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_types(iterator)`.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:349: Iterator.output_shapes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_shapes(iterator)`.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:351: Iterator.output_classes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_classes(iterator)`.
Model trained for 1000 steps.
# 学習済みモデルをテストする

# テスト用データセットを与える
sess.run(iterator.make_initializer(test_ds))

# 精度を計算する
correct_prediction = tf.equal(tf.argmax(y, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % sess.run(accuracy))
Prediction accuracy after training: 0.91
# セッションを片付ける
sess.close()