このページは Cloud Translation API によって翻訳されました。
Switch to English

XLAをtf.functionで使用する

TensorFlow.orgで見る Google Colabで実行 GitHubでソースを表示する

このチュートリアルでは、TensorFlowモデルをトレーニングしてMNISTデータセットを分類します。トレーニング関数はXLAを使用してコンパイルされます。

まず、TensorFlowを読み込み、積極的な実行を有効にします。

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

次に、必要な定数を定義して、MNISTデータセットを準備します。

# Size of each input image, 28 x 28 pixels
IMAGE_SIZE = 28 * 28
# Number of distinct number labels, [0..9]
NUM_CLASSES = 10
# Number of examples in each training batch (step)
TRAIN_BATCH_SIZE = 100
# Number of training steps to run
TRAIN_STEPS = 1000

# Loads MNIST dataset.
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()

# Casting from raw data to the required datatypes.
def cast(images, labels):
  images = tf.cast(
      tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)
  labels = tf.cast(labels, tf.int64)
  return (images, labels)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 1s 0us/step

最後に、モデルとオプティマイザを定義します。モデルは単一の高密度レイヤーを使用します。

layer = tf.keras.layers.Dense(NUM_CLASSES)
optimizer = tf.keras.optimizers.Adam()

トレーニング関数を定義する

トレーニング関数では、上で定義したレイヤーを使用して予測ラベルを取得し、オプティマイザーを使用して損失の勾配を最小化します。 XLAを使用して計算をコンパイルするには、 experimental_compile=True tf.functionしてtf.function内にtf.functionexperimental_compile=True

@tf.function(experimental_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

モデルをトレーニングしてテストする

トレーニング関数を定義したら、モデルを定義します。

for images, labels in train_ds:
  if optimizer.iterations > TRAIN_STEPS:
    break
  train_mnist(images, labels)

そして最後に、精度を確認します。

images, labels = cast(test[0], test[1])
predicted_labels = layer(images)
correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % accuracy)
Prediction accuracy after training: tf.Tensor(0.8829, shape=(), dtype=float32)