ช่วยปกป้อง Great Barrier Reef กับ TensorFlow บน Kaggle เข้าร่วมท้าทาย

ใช้ TPU

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

ก่อนที่คุณจะทำงานนี้โน้ตบุ๊ค Colab ให้แน่ใจว่าการเร่งความเร็วฮาร์ดแวร์ของคุณเป็น TPU โดยการตรวจสอบการตั้งค่าโน้ตบุ๊คของคุณ: Runtime> เปลี่ยนประเภทรันไทม์> เร่งฮาร์ดแวร์> TPU

ติดตั้ง

import tensorflow as tf

import os
import tensorflow_datasets as tfds
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/requests/__init__.py:104: RequestsDependencyWarning: urllib3 (1.26.7) or chardet (2.3.0)/charset_normalizer (2.0.9) doesn't match a supported version!
  RequestsDependencyWarning)

การเริ่มต้น TPU

โดยทั่วไปแล้ว TPU เป็นผู้ปฏิบัติงาน Cloud TPU ซึ่งแตกต่างจากกระบวนการในเครื่องที่เรียกใช้โปรแกรม Python ของผู้ใช้ ดังนั้น คุณต้องดำเนินการเริ่มต้นบางอย่างเพื่อเชื่อมต่อกับคลัสเตอร์ระยะไกลและเริ่มต้น TPU หมายเหตุว่า tpu อาร์กิวเมนต์ tf.distribute.cluster_resolver.TPUClusterResolver ที่อยู่พิเศษเฉพาะสำหรับ Colab หากคุณกำลังเรียกใช้โค้ดบน Google Compute Engine (GCE) คุณควรส่งผ่านชื่อ Cloud TPU ของคุณแทน

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
# This is the TPU initialization code that has to be at the beginning.
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Clearing out eager caches
INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.26:8470
INFO:tensorflow:Initializing the TPU system: grpc://10.240.1.26:8470
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Finished initializing TPU system.
All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]

การจัดวางอุปกรณ์ด้วยตนเอง

หลังจากที่เริ่มต้น TPU แล้ว คุณสามารถใช้การจัดวางอุปกรณ์ด้วยตนเองเพื่อวางการคำนวณบนอุปกรณ์ TPU เครื่องเดียวได้:

a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

with tf.device('/TPU:0'):
  c = tf.matmul(a, b)

print("c device: ", c.device)
print(c)
c device:  /job:worker/replica:0/task:0/device:TPU:0
tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32)

กลยุทธ์การจัดจำหน่าย

โดยปกติ คุณเรียกใช้โมเดลของคุณบน TPU หลายตัวในลักษณะคู่ขนานของข้อมูล ในการเผยแพร่โมเดลของคุณบน TPU หลายตัว (หรือตัวเร่งความเร็วอื่นๆ) TensorFlow ขอเสนอกลยุทธ์การกระจายหลายแบบ คุณสามารถแทนที่กลยุทธ์การจัดจำหน่ายของคุณและโมเดลจะทำงานบนอุปกรณ์ (TPU) ใดก็ตามที่กำหนด ตรวจสอบ คำแนะนำกลยุทธ์การจัดจำหน่าย สำหรับข้อมูลเพิ่มเติม

แสดงให้เห็นถึงนี้สร้าง tf.distribute.TPUStrategy วัตถุ:

strategy = tf.distribute.TPUStrategy(resolver)
INFO:tensorflow:Found TPU system:
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

ที่จะทำซ้ำการคำนวณจึงสามารถทำงานได้ในทุกแกน TPU คุณสามารถส่งมันลงไปใน strategy.run API ด้านล่างนี้เป็นตัวอย่างที่แสดงให้เห็นว่าแกนทั้งหมดที่ได้รับปัจจัยการผลิตเดียวกัน (a, b) และการแสดงคูณเมทริกซ์ในแต่ละหลักอิสระ ผลลัพธ์จะเป็นค่าจากการจำลองทั้งหมด

@tf.function
def matmul_fn(x, y):
  z = tf.matmul(x, y)
  return z

z = strategy.run(matmul_fn, args=(a, b))
print(z)
PerReplica:{
  0: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  1: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  2: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  3: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  4: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  5: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  6: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32),
  7: tf.Tensor(
[[22. 28.]
 [49. 64.]], shape=(2, 2), dtype=float32)
}

การจำแนกประเภทบน TPUs

เมื่อครอบคลุมแนวคิดพื้นฐานแล้ว ให้พิจารณาตัวอย่างที่เป็นรูปธรรมมากขึ้น ในส่วนนี้จะแสดงให้เห็นถึงวิธีการใช้การกระจาย strategy- tf.distribute.TPUStrategy -to ฝึกอบรมรุ่น Keras บนเมฆ TPU

กำหนดแบบจำลอง Keras

เริ่มต้นด้วยความหมายของการเป็น Sequential รุ่น Keras สำหรับการจำแนกประเภทของภาพในชุดข้อมูลที่ MNIST โดยใช้ Keras ไม่ต่างจากสิ่งที่คุณจะใช้หากคุณกำลังฝึกเกี่ยวกับ CPU หรือ GPU โปรดทราบว่าการสร้างความต้องการ Keras รุ่นต้องอยู่ภายใน strategy.scope ดังนั้นตัวแปรที่สามารถสร้างขึ้นบนอุปกรณ์ TPU แต่ละ ส่วนอื่นๆ ของโค้ดไม่จำเป็นต้องอยู่ในขอบเขตของกลยุทธ์

def create_model():
  return tf.keras.Sequential(
      [tf.keras.layers.Conv2D(256, 3, activation='relu', input_shape=(28, 28, 1)),
       tf.keras.layers.Conv2D(256, 3, activation='relu'),
       tf.keras.layers.Flatten(),
       tf.keras.layers.Dense(256, activation='relu'),
       tf.keras.layers.Dense(128, activation='relu'),
       tf.keras.layers.Dense(10)])

โหลดชุดข้อมูล

การใช้งานที่มีประสิทธิภาพของ tf.data.Dataset API เป็นสิ่งสำคัญเมื่อใช้ระบบคลาวด์ TPU ขณะที่มันเป็นไปไม่ได้ที่จะใช้ระบบคลาวด์ TPUs จนกว่าคุณจะสามารถป้อนข้อมูลได้อย่างรวดเร็วพอ คุณสามารถเรียนรู้เพิ่มเติมเกี่ยวกับประสิทธิภาพชุดข้อมูลใน คู่มือการปฏิบัติงานท่อขาเข้า

สำหรับทุกคน แต่การทดลองง่าย (ใช้ tf.data.Dataset.from_tensor_slices หรือข้อมูลในกราฟอื่น ๆ ) คุณจะต้องจัดเก็บไฟล์ข้อมูลทั้งหมดอ่านได้โดยชุดข้อมูลใน Google Cloud Storage (GCS) บุ้งกี๋

สำหรับกรณีการใช้งานส่วนใหญ่ก็จะแนะนำให้แปลงข้อมูลของคุณลงใน TFRecord รูปแบบและใช้ tf.data.TFRecordDataset ที่จะอ่านมัน ตรวจสอบ TFRecord และ tf.Example กวดวิชา สำหรับรายละเอียดเกี่ยวกับวิธีการทำเช่นนี้ มันไม่ได้เป็นความต้องการอย่างหนักและคุณสามารถใช้ผู้อ่านชุดข้อมูลอื่น ๆ เช่น tf.data.FixedLengthRecordDataset หรือ tf.data.TextLineDataset

คุณสามารถโหลดชุดข้อมูลขนาดเล็กทั้งหมดลงในหน่วยความจำโดยใช้ tf.data.Dataset.cache

โดยไม่คำนึงถึงรูปแบบข้อมูลที่ใช้ ขอแนะนำอย่างยิ่งให้คุณใช้ไฟล์ขนาดใหญ่ตามลำดับ 100MB นี่เป็นสิ่งสำคัญอย่างยิ่งในการตั้งค่าเครือข่ายนี้ เนื่องจากค่าใช้จ่ายในการเปิดไฟล์สูงขึ้นอย่างมาก

ดังแสดงในรหัสข้างล่างนี้คุณควรใช้ tensorflow_datasets โมดูลจะได้รับสำเนาของการฝึกอบรม MNIST และข้อมูลการทดสอบ โปรดทราบว่า try_gcs มีการระบุให้ใช้สำเนาที่มีอยู่ในถัง GCS สาธารณะ หากไม่ระบุ TPU จะไม่สามารถเข้าถึงข้อมูลที่ดาวน์โหลด

def get_dataset(batch_size, is_training=True):
  split = 'train' if is_training else 'test'
  dataset, info = tfds.load(name='mnist', split=split, with_info=True,
                            as_supervised=True, try_gcs=True)

  # Normalize the input data.
  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.0
    return image, label

  dataset = dataset.map(scale)

  # Only shuffle and repeat the dataset in training. The advantage of having an
  # infinite dataset for training is to avoid the potential last partial batch
  # in each epoch, so that you don't need to think about scaling the gradients
  # based on the actual batch size.
  if is_training:
    dataset = dataset.shuffle(10000)
    dataset = dataset.repeat()

  dataset = dataset.batch(batch_size)

  return dataset

ฝึกโมเดลโดยใช้ API ระดับสูงของ Keras

คุณสามารถฝึกอบรมรูปแบบของคุณกับ Keras fit และ compile APIs ไม่มีอะไร TPU เฉพาะในเรื่องนี้เป็นขั้นตอนที่คุณเขียนรหัสราวกับว่าคุณกำลังใช้ GPUs mutliple และ MirroredStrategy แทน TPUStrategy คุณสามารถเรียนรู้มากขึ้นใน การฝึกอบรมกระจายกับ Keras กวดวิชา

with strategy.scope():
  model = create_model()
  model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])

batch_size = 200
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size

train_dataset = get_dataset(batch_size, is_training=True)
test_dataset = get_dataset(batch_size, is_training=False)

model.fit(train_dataset,
          epochs=5,
          steps_per_epoch=steps_per_epoch,
          validation_data=test_dataset, 
          validation_steps=validation_steps)
Epoch 1/5
300/300 [==============================] - 19s 32ms/step - loss: 0.1344 - sparse_categorical_accuracy: 0.9581 - val_loss: 0.0481 - val_sparse_categorical_accuracy: 0.9852
Epoch 2/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0347 - sparse_categorical_accuracy: 0.9891 - val_loss: 0.0440 - val_sparse_categorical_accuracy: 0.9870
Epoch 3/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0190 - sparse_categorical_accuracy: 0.9938 - val_loss: 0.0404 - val_sparse_categorical_accuracy: 0.9881
Epoch 4/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0123 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.0366 - val_sparse_categorical_accuracy: 0.9895
Epoch 5/5
300/300 [==============================] - 6s 21ms/step - loss: 0.0100 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.0422 - val_sparse_categorical_accuracy: 0.9881
<keras.callbacks.History at 0x7f9d7c0e73c8>

เพื่อลดค่าใช้จ่ายหลามและเพิ่มประสิทธิภาพการทำงานของ TPU ของคุณผ่านใน argument- steps_per_execution -to Model.compile ในตัวอย่างนี้ จะเพิ่มปริมาณงานประมาณ 50%:

with strategy.scope():
  model = create_model()
  model.compile(optimizer='adam',
                # Anything between 2 and `steps_per_epoch` could help here.
                steps_per_execution = 50,
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['sparse_categorical_accuracy'])

model.fit(train_dataset,
          epochs=5,
          steps_per_epoch=steps_per_epoch,
          validation_data=test_dataset,
          validation_steps=validation_steps)
Epoch 1/5
300/300 [==============================] - 12s 41ms/step - loss: 0.1656 - sparse_categorical_accuracy: 0.9512 - val_loss: 0.0429 - val_sparse_categorical_accuracy: 0.9863
Epoch 2/5
300/300 [==============================] - 3s 10ms/step - loss: 0.0354 - sparse_categorical_accuracy: 0.9890 - val_loss: 0.0460 - val_sparse_categorical_accuracy: 0.9862
Epoch 3/5
300/300 [==============================] - 3s 10ms/step - loss: 0.0204 - sparse_categorical_accuracy: 0.9930 - val_loss: 0.0522 - val_sparse_categorical_accuracy: 0.9843
Epoch 4/5
300/300 [==============================] - 3s 10ms/step - loss: 0.0121 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.0483 - val_sparse_categorical_accuracy: 0.9866
Epoch 5/5
300/300 [==============================] - 3s 10ms/step - loss: 0.0089 - sparse_categorical_accuracy: 0.9973 - val_loss: 0.0457 - val_sparse_categorical_accuracy: 0.9883
<keras.callbacks.History at 0x7f9d7c13f2e8>

ฝึกโมเดลโดยใช้ลูปการฝึกแบบกำหนดเอง

นอกจากนี้คุณยังสามารถสร้างและการฝึกอบรมรุ่นของคุณโดยใช้ tf.function และ tf.distribute APIs โดยตรง คุณสามารถใช้ strategy.experimental_distribute_datasets_from_function API เพื่อแจกจ่ายชุดข้อมูลที่ได้รับฟังก์ชั่นชุด โปรดทราบว่าในตัวอย่างด้านล่างขนาดชุดงานที่ส่งผ่านไปยังชุดข้อมูลจะเป็นขนาดชุดงานต่อแบบจำลองแทนที่จะเป็นขนาดชุดงานส่วนกลาง ต้องการเรียนรู้เพิ่มเติมให้ตรวจสอบ การฝึกอบรมที่กำหนดเองกับ tf.distribute.Strategy กวดวิชา

ขั้นแรก สร้างโมเดล ชุดข้อมูล และ tf.functions:

# Create the model, optimizer and metrics inside the strategy scope, so that the
# variables can be mirrored on each device.
with strategy.scope():
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()
  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'training_accuracy', dtype=tf.float32)

# Calculate per replica batch size, and distribute the datasets on each TPU
# worker.
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync

train_dataset = strategy.experimental_distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))

@tf.function
def train_step(iterator):
  """The step function for one training step."""

  def step_fn(inputs):
    """The computation to run on each TPU device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    training_loss.update_state(loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, logits)

  strategy.run(step_fn, args=(next(iterator),))
WARNING:tensorflow:From <ipython-input-1-5625c2a14441>:15: StrategyBase.experimental_distribute_datasets_from_function (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
rename to distribute_datasets_from_function
WARNING:tensorflow:From <ipython-input-1-5625c2a14441>:15: StrategyBase.experimental_distribute_datasets_from_function (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
rename to distribute_datasets_from_function

จากนั้นรันลูปการฝึก:

steps_per_eval = 10000 // batch_size

train_iterator = iter(train_dataset)
for epoch in range(5):
  print('Epoch: {}/5'.format(epoch))

  for step in range(steps_per_epoch):
    train_step(train_iterator)
  print('Current step: {}, training loss: {}, accuracy: {}%'.format(
      optimizer.iterations.numpy(),
      round(float(training_loss.result()), 4),
      round(float(training_accuracy.result()) * 100, 2)))
  training_loss.reset_states()
  training_accuracy.reset_states()
Epoch: 0/5
Current step: 300, training loss: 0.2007, accuracy: 94.27%
Epoch: 1/5
Current step: 600, training loss: 0.0374, accuracy: 98.8%
Epoch: 2/5
Current step: 900, training loss: 0.0203, accuracy: 99.35%
Epoch: 3/5
Current step: 1200, training loss: 0.0138, accuracy: 99.57%
Epoch: 4/5
Current step: 1500, training loss: 0.0105, accuracy: 99.66%

การปรับปรุงประสิทธิภาพการทำงานที่มีหลายขั้นตอนภายใน tf.function

คุณสามารถปรับปรุงประสิทธิภาพการทำงานโดยการทำงานหลายขั้นตอนภายใน tf.function นี้จะทำได้โดยการตัด strategy.run โทรที่มี tf.range ภายใน tf.function และลายเซ็นจะแปลงเป็น tf.while_loop ในคนงาน TPU

แม้จะมีการปรับปรุงประสิทธิภาพการทำงานที่มีความสมดุลด้วยวิธีนี้เมื่อเทียบกับการทำงานขั้นตอนเดียวภายใน tf.function ทำงานหลายขั้นตอนใน tf.function มีความยืดหยุ่นน้อยคุณไม่สามารถเรียกใช้สิ่งที่กระหายหรือรหัสหลามพลภายในขั้นตอน

@tf.function
def train_multiple_steps(iterator, steps):
  """The step function for one training step."""

  def step_fn(inputs):
    """The computation to run on each TPU device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    training_loss.update_state(loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, logits)

  for _ in tf.range(steps):
    strategy.run(step_fn, args=(next(iterator),))

# Convert `steps_per_epoch` to `tf.Tensor` so the `tf.function` won't get 
# retraced if the value changes.
train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))

print('Current step: {}, training loss: {}, accuracy: {}%'.format(
      optimizer.iterations.numpy(),
      round(float(training_loss.result()), 4),
      round(float(training_accuracy.result()) * 100, 2)))
Current step: 1800, training loss: 0.0079, accuracy: 99.77%

ขั้นตอนถัดไป