XLAコンパイラAPI

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

TensorFlowとXLAライブラリをインポートします。XLAには、一部または全てのモデルを XLA でコンパイルする実験的なAPIである xla.compile() が含まれています。

import tensorflow as tf

from tensorflow.contrib.compiler import xla

必要ないくつかの定数を定義し、 MNISTのデータセットを用意します。

# それぞれの入力イメージの大きさは、 28 x 28ピクセル
IMAGE_SIZE = 28 * 28
# 個別の数字のラベル [0..9] の個数
NUM_CLASSES = 10
# それぞれのトレーニングバッチ(ステップ)での標本数
TRAIN_BATCH_SIZE = 100
# トレーニングステップを実行する回数
TRAIN_STEPS = 1000
# MNISTデータセットをロードする。
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()
test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)

iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)
images, labels = iterator.get_next()
images = tf.reshape(images, [-1, IMAGE_SIZE])
images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
WARNING:tensorflow:From <ipython-input-4-4b4b30f2fbb2>:5: DatasetV1.output_types (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_types(dataset)`.
WARNING:tensorflow:From <ipython-input-4-4b4b30f2fbb2>:5: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_shapes(dataset)`.

モデルを構築する関数の定義

以下のコードブロックは、順伝搬と逆伝搬の両方を行う、1つのdenseレイヤーを持つ簡単なモデルを構築する関数を含みます。

コードが呼ばれたとき、2つの値を返します。 y は、それぞれのターゲットのクラスの予測確率を表す tf.Tensor です。 train_stepglobal_step の値を増加し、変数の更新を行う tf.Operation です。

def build_mnist_model(x, y_):
  y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)

  cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  return y, train_step

XLA の有効化

XLA を有効化するには build_mnist_model 関数を xla.compile に渡します。以下のコードブロックは、モデルを xla.compile() 関数でラップします。これにより、提供された入力を持つターゲット関数をXLAで実行できます。

[y] = xla.compile(build_mnist_model, inputs=[images, labels])
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/losses/losses_impl.py:121: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

グラフをコンパイルするとき、XLAはターゲット関数によって構築されたグラフの全てのノードを、いくつかのXLAのオペレータで置き換えます。

xla.compileは、生成されたXLAのオペレータから独立して実行できる tf.Operation を返しません 代わりに、ターゲット関数から返された tf.Operation ノードは、返された全ての tf.Tensor の値との制御依存関係として追加されます。これにより、 返されたテンソルが評価されるときに、 tf.Operation ノードの実行をトリガします。

擬似コードによるxla.compileの実装は、以下のようになります:


# TensorFlowに、XLAが扱いやすい方法でコードを実行するよう依頼する

y, train_step = build_mnist_model(images, labels)
with tf.control_dependencies([train_step]):
  y = tf.identity(y)

# TensorFlowに、XLAが扱いやすい方法でコードの実行を停止するよう依頼する

xla.compile()は常に tf.Tensor のリスト(1要素しか無かったとしても)を返します。

もしあなたが構築したグラフを今表示したら、通常のTensorFlowのグラフとそれほど変わらないことがわかり、前に述べたXLAのオペレータを見つけることができないでしょう。これは、あなたが sess.run() でグラフを実行しようとしても、実際のコンパイルは後ほど発生するからです。後ほど、TensorFlowは実際にXLAオペレータを生成する一連のグラフ書き換えパスをトリガーします。これは、すべての入力がそろったときに、計算をコンパイルして実行します。

モデルの学習とテスト

# セッションを作成しすべての変数を初期化。
# xla.compile()は、Keras model.fit() APIやTF eager modeとはまだ動作しません。
sess = tf.Session()
sess.run(tf.global_variables_initializer())

以下のコードブロックはモデルを学習します。 y の評価は、制御依存関係がある train_step をトリガします。これは、モデル変数を更新します。

# 学習用データセットを与える
sess.run(iterator.make_initializer(train_ds))

# TRAIN_STEPS ステップだけ実行する
for i in range(TRAIN_STEPS):
  sess.run(y)

print("Model trained for %s steps." % TRAIN_STEPS)
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:348: Iterator.output_types (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_types(iterator)`.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:349: Iterator.output_shapes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_shapes(iterator)`.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:351: Iterator.output_classes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.compat.v1.data.get_output_classes(iterator)`.
Model trained for 1000 steps.
# 学習済みモデルをテストする

# テスト用データセットを与える
sess.run(iterator.make_initializer(test_ds))

# 精度を計算する
correct_prediction = tf.equal(tf.argmax(y, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % sess.run(accuracy))
Prediction accuracy after training: 0.91
# セッションを片付ける
sess.close()