การฝึกอบรมแบบกำหนดเองด้วย tf.distribute.Strategy

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

บทช่วยสอนนี้สาธิตวิธีใช้ tf.distribute.Strategy พร้อมลูปการฝึกแบบกำหนดเอง เราจะฝึกโมเดล CNN อย่างง่ายบนชุดข้อมูล MNIST ของแฟชั่น ชุดข้อมูลแฟชั่น MNIST ประกอบด้วยรูปภาพรถไฟ 60000 รูปขนาด 28 x 28 และรูปภาพทดสอบ 10000 รูปขนาด 28 x 28

เรากำลังใช้ลูปการฝึกแบบกำหนดเองเพื่อฝึกโมเดลของเรา เนื่องจากช่วยให้มีความยืดหยุ่นและควบคุมการฝึกได้มากขึ้น ยิ่งไปกว่านั้น การดีบักโมเดลและลูปการฝึกทำได้ง่ายขึ้น

# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

print(tf.__version__)
2.8.0-rc1

ดาวน์โหลดชุดข้อมูลแฟชั่น MNIST

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Adding a dimension to the array -> new shape == (28, 28, 1)
# We are doing this because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Getting the images in [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)

สร้างกลยุทธ์เพื่อกระจายตัวแปรและกราฟ

กลยุทธ์ tf.distribute.MirroredStrategy ทำงานอย่างไร

  • ตัวแปรทั้งหมดและกราฟแบบจำลองถูกจำลองแบบบนแบบจำลอง
  • อินพุตถูกกระจายอย่างเท่าเทียมกันทั่วทั้งแบบจำลอง
  • แบบจำลองแต่ละรายการจะคำนวณการสูญเสียและการไล่ระดับสีสำหรับอินพุตที่ได้รับ
  • การไล่ระดับสีจะซิงค์กับแบบจำลองทั้งหมดโดยการรวมเข้าด้วยกัน
  • หลังจากการซิงค์ จะมีการอัพเดทแบบเดียวกันกับสำเนาของตัวแปรในแต่ละเรพพลิกา
# If the list of devices is not specified in the
# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

ตั้งค่าไปป์ไลน์อินพุต

ส่งออกกราฟและตัวแปรไปยังรูปแบบ SavedModel ที่ไม่เชื่อเรื่องพระเจ้าบนแพลตฟอร์ม หลังจากที่แบบจำลองของคุณได้รับการบันทึกแล้ว คุณสามารถโหลดแบบจำลองโดยมีหรือไม่มีขอบเขตก็ได้

BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

สร้างชุดข้อมูลและแจกจ่าย:

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
2022-01-26 05:45:53.991501: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_UINT8
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 60000
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:0"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 1
        }
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
}

2022-01-26 05:45:54.034762: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_FLOAT
      type: DT_UINT8
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 10000
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:3"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: 28
        }
        dim {
          size: 28
        }
        dim {
          size: 1
        }
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_FLOAT
        }
      }
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_UINT8
        }
      }
    }
  }
}

สร้างแบบจำลอง

สร้างโมเดลโดยใช้ tf.keras.Sequential คุณยังสามารถใช้ Model Subclassing API เพื่อทำสิ่งนี้ได้

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
    ])

  return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

กำหนดฟังก์ชันการสูญเสีย

โดยปกติ ในเครื่องเดียวที่มี 1 GPU/CPU การสูญเสียจะถูกหารด้วยจำนวนตัวอย่างในชุดอินพุต

ดังนั้นควรคำนวณการสูญเสียอย่างไรเมื่อใช้ tf.distribute.Strategy ?

  • ตัวอย่างเช่น สมมติว่าคุณมี 4 GPU และขนาดแบทช์ 64 อินพุตหนึ่งชุดถูกแจกจ่ายทั่วทั้งแบบจำลอง (4 GPU) แต่ละแบบจำลองจะได้รับอินพุตขนาด 16

  • โมเดลในแต่ละแบบจำลองจะส่งต่อด้วยอินพุตที่เกี่ยวข้องและคำนวณการสูญเสีย ตอนนี้ แทนที่จะหารการสูญเสียด้วยจำนวนตัวอย่างในอินพุตที่เกี่ยวข้อง (BATCH_SIZE_PER_REPLICA = 16) การขาดทุนควรหารด้วย GLOBAL_BATCH_SIZE (64)

ทำไมทำเช่นนี้?

  • จำเป็นต้องทำสิ่งนี้เพราะหลังจากคำนวณการไล่ระดับสีในแต่ละแบบจำลองแล้ว จะซิงค์ข้ามแบบจำลองโดยการรวม เข้า ด้วยกัน

จะทำสิ่งนี้ใน TensorFlow ได้อย่างไร?

  • หากคุณกำลังเขียนลูปการฝึกแบบกำหนดเอง เช่นเดียวกับในบทช่วยสอนนี้ คุณควรรวมการสูญเสียต่อตัวอย่างและหารผลรวมด้วย GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE) หรือคุณสามารถใช้ tf.nn.compute_average_loss ซึ่งรับการสูญเสียต่อตัวอย่าง น้ำหนักตัวอย่างที่เลือกได้ และ GLOBAL_BATCH_SIZE เป็นอาร์กิวเมนต์ และส่งกลับค่าการสูญเสียตามสัดส่วน

  • หากคุณกำลังใช้การสูญเสียการทำให้เป็นมาตรฐานในแบบจำลองของคุณ คุณจะต้องปรับขนาดค่าการสูญเสียตามจำนวนแบบจำลอง คุณสามารถทำได้โดยใช้ฟังก์ชัน tf.nn.scale_regularization_loss

  • ไม่แนะนำให้ใช้ tf.reduce_mean การทำเช่นนี้จะแบ่งการสูญเสียตามจริงต่อขนาดแบทช์แบบจำลอง ซึ่งอาจแตกต่างกันไปทีละขั้นตอน

  • การย่อและปรับขนาดนี้ทำได้โดยอัตโนมัติใน keras model.compile และ model.fit

  • หากใช้คลาส tf.keras.losses (ดังในตัวอย่างด้านล่าง) การลดการสูญเสียจะต้องระบุอย่างชัดเจนให้เป็นหนึ่งใน NONE หรือ SUM AUTO และ SUM_OVER_BATCH_SIZE ไม่อนุญาตเมื่อใช้กับ tf.distribute.Strategy ไม่อนุญาตให้ใช้ AUTO เนื่องจากผู้ใช้ควรคิดอย่างชัดแจ้งเกี่ยวกับการลดหย่อนที่ต้องการเพื่อให้แน่ใจว่าถูกต้องในกรณีที่แจกจ่าย ไม่อนุญาต SUM_OVER_BATCH_SIZE เนื่องจากในปัจจุบันจะหารด้วยขนาดแบทช์แบบจำลองเท่านั้น และปล่อยให้ผู้ใช้หารด้วยจำนวนแบบจำลอง ซึ่งอาจพลาดได้ง่าย ดังนั้นเราจึงขอให้ผู้ใช้ทำการลดลงอย่างชัดเจนแทน

  • หาก labels เป็นแบบหลายมิติ ให้เฉลี่ย per_example_loss จากจำนวนองค์ประกอบในแต่ละตัวอย่าง ตัวอย่างเช่น หากรูปแบบการ predictions คือ (batch_size, H, W, n_classes) และ labels คือ (batch_size, H, W) คุณจะต้องอัปเดต per_example_loss เช่น: per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)

with strategy.scope():
  # Set reduction to `none` so we can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

กำหนดตัวชี้วัดเพื่อติดตามการสูญเสียและความแม่นยำ

ตัวชี้วัดเหล่านี้ติดตามการสูญเสียการทดสอบและการฝึกอบรมและความแม่นยำในการทดสอบ คุณสามารถใช้ . .result() เพื่อรับสถิติสะสมได้ตลอดเวลา

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

วงการฝึก

# model, optimizer, and checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print (template.format(epoch+1, train_loss,
                         train_accuracy.result()*100, test_loss.result(),
                         test_accuracy.result()*100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()
Epoch 1, Loss: 0.5106383562088013, Accuracy: 81.77999877929688, Test Loss: 0.39399346709251404, Test Accuracy: 85.79000091552734
Epoch 2, Loss: 0.3362727463245392, Accuracy: 87.91333770751953, Test Loss: 0.35871225595474243, Test Accuracy: 86.7699966430664
Epoch 3, Loss: 0.2928692400455475, Accuracy: 89.2683334350586, Test Loss: 0.2999486029148102, Test Accuracy: 89.04000091552734
Epoch 4, Loss: 0.2605818510055542, Accuracy: 90.41999816894531, Test Loss: 0.28474125266075134, Test Accuracy: 89.47000122070312
Epoch 5, Loss: 0.23641237616539001, Accuracy: 91.32166290283203, Test Loss: 0.26421546936035156, Test Accuracy: 90.41000366210938
Epoch 6, Loss: 0.2192477434873581, Accuracy: 91.90499877929688, Test Loss: 0.2650589942932129, Test Accuracy: 90.4800033569336
Epoch 7, Loss: 0.20016911625862122, Accuracy: 92.66999816894531, Test Loss: 0.25025954842567444, Test Accuracy: 90.9000015258789
Epoch 8, Loss: 0.18381091952323914, Accuracy: 93.26499938964844, Test Loss: 0.2585820257663727, Test Accuracy: 90.95999908447266
Epoch 9, Loss: 0.1699329912662506, Accuracy: 93.67500305175781, Test Loss: 0.26234227418899536, Test Accuracy: 91.0199966430664
Epoch 10, Loss: 0.15756534039974213, Accuracy: 94.16333770751953, Test Loss: 0.25516414642333984, Test Accuracy: 90.93000030517578

สิ่งที่ควรทราบในตัวอย่างข้างต้น:

  • เรากำลังวนซ้ำบน train_dist_dataset และ test_dist_dataset โดยใช้ a for x in ... construct
  • การสูญเสียตามมาตราส่วนคือค่าส่งคืนของ distributed_train_step ค่านี้ถูกรวมระหว่างแบบจำลองโดยใช้การเรียก tf.distribute.Strategy.reduce จากนั้นข้ามแบทช์โดยการรวมค่าที่ส่งคืนของการเรียก tf.distribute.Strategy.reduce
  • tf.keras.Metrics ควรได้รับการอัปเดตภายใน train_step และ test_step ที่ดำเนินการโดย tf.distribute.Strategy.run * tf.distribute.Strategy.run ส่งคืนผลลัพธ์จากการจำลองแบบโลคัลแต่ละรายการในกลยุทธ์ และมีหลายวิธีในการใช้ผลลัพธ์นี้ คุณสามารถทำ tf.distribute.Strategy.reduce เพื่อรับมูลค่ารวม คุณยังสามารถทำ tf.distribute.Strategy.experimental_local_results เพื่อรับรายการค่าที่มีอยู่ในผลลัพธ์ หนึ่งค่าต่อแบบจำลองในเครื่อง

คืนค่าจุดตรวจล่าสุดและทดสอบ

โมเดลที่มีจุดตรวจสอบด้วย tf.distribute.Strategy สามารถกู้คืนได้โดยมีหรือไม่มีกลยุทธ์

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print ('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result()*100))
Accuracy after restoring the saved model without strategy: 91.0199966430664

วิธีอื่นในการวนซ้ำชุดข้อมูล

การใช้ตัววนซ้ำ

หากคุณต้องการวนซ้ำตามจำนวนขั้นตอนที่กำหนด และไม่ผ่านชุดข้อมูลทั้งหมด คุณสามารถสร้างตัววนซ้ำได้โดยใช้การเรียก iter และการเรียกแบบชัดแจ้ง next บนตัววนซ้ำ คุณสามารถเลือกที่จะวนซ้ำชุดข้อมูลทั้งภายในและภายนอก tf.function นี่คือตัวอย่างเล็กๆ ที่สาธิตการวนซ้ำของชุดข้อมูลนอก tf.function โดยใช้ตัววนซ้ำ

for _ in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))
  train_accuracy.reset_states()
Epoch 10, Loss: 0.17486707866191864, Accuracy: 93.4375
Epoch 10, Loss: 0.12386945635080338, Accuracy: 95.3125
Epoch 10, Loss: 0.16411852836608887, Accuracy: 93.90625
Epoch 10, Loss: 0.10728752613067627, Accuracy: 96.40625
Epoch 10, Loss: 0.11865834891796112, Accuracy: 95.625
Epoch 10, Loss: 0.12875251471996307, Accuracy: 95.15625
Epoch 10, Loss: 0.1189488023519516, Accuracy: 95.625
Epoch 10, Loss: 0.1456708014011383, Accuracy: 95.15625
Epoch 10, Loss: 0.12446556240320206, Accuracy: 95.3125
Epoch 10, Loss: 0.1380888819694519, Accuracy: 95.46875

วนซ้ำภายใน tf.function

คุณยังสามารถวนซ้ำอินพุต train_dist_dataset ทั้งหมดภายใน tf.function โดยใช้ for x in ... construct หรือโดยการสร้าง iterators เหมือนที่เราทำด้านบน ตัวอย่างด้านล่างแสดงให้เห็นถึงการห่อหนึ่งช่วงของการฝึกใน tf.function และวนซ้ำบน train_dist_dataset ภายในฟังก์ชัน

@tf.function
def distributed_train_epoch(dataset):
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))

  train_accuracy.reset_states()
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:449: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
Epoch 1, Loss: 0.14398494362831116, Accuracy: 94.63999938964844
Epoch 2, Loss: 0.13246288895606995, Accuracy: 94.97333526611328
Epoch 3, Loss: 0.11922841519117355, Accuracy: 95.63833618164062
Epoch 4, Loss: 0.11084160208702087, Accuracy: 95.99333190917969
Epoch 5, Loss: 0.10420522093772888, Accuracy: 96.0816650390625
Epoch 6, Loss: 0.09215126931667328, Accuracy: 96.63500213623047
Epoch 7, Loss: 0.0878651961684227, Accuracy: 96.67666625976562
Epoch 8, Loss: 0.07854588329792023, Accuracy: 97.09333038330078
Epoch 9, Loss: 0.07217177003622055, Accuracy: 97.34833526611328
Epoch 10, Loss: 0.06753655523061752, Accuracy: 97.48999786376953

ติดตามการสูญเสียการฝึกอบรมข้ามแบบจำลอง

เรา ไม่ แนะนำให้ใช้ tf.metrics.Mean เพื่อติดตามการสูญเสียการฝึกในแบบจำลองต่างๆ เนื่องจากการคำนวณมาตราส่วนการสูญเสียที่ดำเนินการ

ตัวอย่างเช่น หากคุณเรียกใช้งานการฝึกอบรมที่มีลักษณะดังต่อไปนี้:

  • สองแบบจำลอง
  • สองตัวอย่างจะถูกประมวลผลในแต่ละแบบจำลอง
  • ค่าการสูญเสียผลลัพธ์: [2, 3] และ [4, 5] ในแต่ละแบบจำลอง
  • ขนาดแบทช์ทั่วโลก = 4

ด้วยมาตราส่วนการสูญเสีย คุณจะคำนวณมูลค่าต่อตัวอย่างของการสูญเสียในแต่ละแบบจำลองโดยการเพิ่มค่าการสูญเสีย แล้วหารด้วยขนาดชุดงานส่วนกลาง ในกรณีนี้: (2 + 3) / 4 = 1.25 และ (4 + 5) / 4 = 2.25 .

หากคุณใช้ tf.metrics.Mean เพื่อติดตามการสูญเสียในแบบจำลองทั้งสอง ผลลัพธ์จะแตกต่างกัน ในตัวอย่างนี้ คุณลงเอยด้วยผล total 3.50 และ count เป็น 2 ซึ่งส่งผลให้ผล total / count = 1.75 เมื่อมีการเรียก result() บนเมตริก การสูญเสียที่คำนวณด้วย tf.keras.Metrics จะถูกปรับขนาดโดยปัจจัยเพิ่มเติมที่เท่ากับจำนวนของแบบจำลองที่ซิงค์

คู่มือและตัวอย่าง

ต่อไปนี้คือตัวอย่างบางส่วนสำหรับการใช้กลยุทธ์การกระจายกับลูปการฝึกแบบกำหนดเอง:

  1. คู่มือการอบรมแบบกระจาย
  2. ตัวอย่าง DenseNet โดยใช้ MirroredStrategy
  3. ตัวอย่าง BERT ที่ฝึกโดยใช้ MirroredStrategy และ TPUStrategy ตัวอย่างนี้มีประโยชน์อย่างยิ่งในการทำความเข้าใจวิธีการโหลดจากจุดตรวจ และสร้างจุดตรวจเป็นระยะระหว่างการฝึกอบรมแบบกระจาย ฯลฯ
  4. ตัวอย่าง NCF ที่ฝึกโดยใช้ MirroredStrategy ที่สามารถเปิดใช้งานได้โดยใช้แฟ keras_use_ctl
  5. ตัวอย่าง NMT ฝึกโดยใช้ MirroredStrategy

ตัวอย่างเพิ่มเติมที่ระบุไว้ใน คู่มือกลยุทธ์การจัดจำหน่าย

ขั้นตอนถัดไป

  • ลองใช้ tf.distribute.Strategy API ใหม่กับโมเดลของคุณ
  • เยี่ยมชม ส่วนประสิทธิภาพ ในคู่มือเพื่อเรียนรู้เพิ่มเติมเกี่ยวกับกลยุทธ์และ เครื่องมือ อื่นๆ ที่คุณสามารถใช้เพื่อเพิ่มประสิทธิภาพการทำงานของแบบจำลอง TensorFlow ของคุณ