การเขียนวงจรการฝึกตั้งแต่เริ่มต้น

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

ติดตั้ง

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

บทนำ

Keras ให้การฝึกอบรมและการประเมินผลเริ่มต้นห่วง fit() และ evaluate() การใช้งานของพวกเขาได้รับการคุ้มครองในคู่มือ การฝึกอบรมและการประเมินผลที่มีในตัววิธีการ

หากคุณต้องการที่จะกำหนดขั้นตอนวิธีการเรียนรู้ของรูปแบบของคุณในขณะที่ยังคงใช้ประโยชน์จากความสะดวกสบายของ fit() (ตัวอย่างเช่นในการฝึกอบรม GAN โดยใช้ fit() ) คุณสามารถซับคลาสตัว Model การเรียนและการดำเนินการของคุณเอง train_step() วิธีการซึ่ง เรียกว่าซ้ำ ๆ ระหว่าง fit() นี้ได้รับการคุ้มครองในคู่มือ การปรับแต่งสิ่งที่เกิดขึ้นใน fit()

ตอนนี้ ถ้าคุณต้องการควบคุมการฝึกอบรมและการประเมินในระดับต่ำมาก คุณควรเขียนลูปการฝึกอบรมและการประเมินของคุณเองตั้งแต่ต้น นี่คือสิ่งที่คู่มือนี้เป็นเรื่องเกี่ยวกับ

ใช้ GradientTape : ตัวอย่างแบบ end-to-end แรก

โทรรูปแบบภายในเป็น GradientTape ขอบเขตช่วยให้คุณสามารถเรียกดูการไล่ระดับสีของน้ำหนักสุวินัยของชั้นที่เกี่ยวกับมูลค่าการสูญเสีย ใช้อินสแตนซ์เพิ่มประสิทธิภาพคุณสามารถใช้การไล่ระดับสีเหล่านี้เพื่อปรับปรุงตัวแปรเหล่านี้ (ซึ่งคุณสามารถเรียกใช้ model.trainable_weights )

ลองพิจารณาโมเดล MNIST อย่างง่าย:

inputs = keras.Input(shape=(784,), name="digits")
x1 = layers.Dense(64, activation="relu")(inputs)
x2 = layers.Dense(64, activation="relu")(x1)
outputs = layers.Dense(10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)

มาฝึกโดยใช้การไล่ระดับสีแบบกลุ่มย่อยด้วยลูปการฝึกแบบกำหนดเองกัน

อันดับแรก เราต้องการเครื่องมือเพิ่มประสิทธิภาพ ฟังก์ชันการสูญเสีย และชุดข้อมูล:

# Instantiate an optimizer.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the training dataset.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 1s 0us/step
11501568/11490434 [==============================] - 1s 0us/step

นี่คือวงการฝึกอบรมของเรา:

  • เราเปิด for วงที่ iterates กว่า epochs
  • สำหรับแต่ละยุคเราเปิด for วงที่ iterates มากกว่าชุดใน batches
  • สำหรับแต่ละชุดเราเปิด GradientTape() ขอบเขต
  • ภายในขอบเขตนี้ เราเรียกโมเดล (forward pass) และคำนวณการสูญเสีย
  • นอกขอบเขต เราดึงการไล่ระดับสีของน้ำหนักของแบบจำลองโดยคำนึงถึงการสูญเสีย
  • สุดท้าย เราใช้เครื่องมือเพิ่มประสิทธิภาพเพื่ออัปเดตน้ำหนักของแบบจำลองตามการไล่ระดับสี
epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:

            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * batch_size))
Start of epoch 0
Training loss (for one batch) at step 0: 68.7478
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.9448
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.1859
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.6914
Seen so far: 38464 samples

Start of epoch 1
Training loss (for one batch) at step 0: 0.9113
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.9550
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.5139
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.7227
Seen so far: 38464 samples

การจัดการเมตริกในระดับต่ำ

มาเพิ่มการตรวจสอบเมตริกให้กับลูปพื้นฐานนี้

คุณสามารถใช้เมตริกในตัวซ้ำได้ (หรือเมตริกที่คุณเขียนเอง) ในลูปการฝึกอบรมที่เขียนขึ้นใหม่ทั้งหมด นี่คือการไหล:

  • สร้างอินสแตนซ์เมตริกที่จุดเริ่มต้นของลูป
  • โทร metric.update_state() หลังจากที่แต่ละชุด
  • โทร metric.result() เมื่อคุณต้องการที่จะแสดงค่าปัจจุบันของตัวชี้วัด
  • โทร metric.reset_states() เมื่อคุณจำเป็นต้องล้างสถานะของตัวชี้วัด (โดยทั่วไปในตอนท้ายของยุค)

ลองใช้ความรู้นี้ในการคำนวณ SparseCategoricalAccuracy ในการตรวจสอบข้อมูลในตอนท้ายของแต่ละยุค:

# Get model
inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

นี่คือการฝึกอบรมและการประเมินของเรา:

import time

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))
Start of epoch 0
Training loss (for one batch) at step 0: 88.9958
Seen so far: 64 samples
Training loss (for one batch) at step 200: 2.2214
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.3083
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.8282
Seen so far: 38464 samples
Training acc over epoch: 0.7406
Validation acc: 0.8201
Time taken: 6.31s

Start of epoch 1
Training loss (for one batch) at step 0: 0.3276
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.4819
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.5971
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.5862
Seen so far: 38464 samples
Training acc over epoch: 0.8474
Validation acc: 0.8676
Time taken: 5.98s

เร่งขึ้นขั้นตอนการฝึกอบรมของคุณด้วย tf.function

รันไทม์ค่าเริ่มต้นใน TensorFlow ที่ 2 คือ การดำเนินการกระตือรือร้น ด้วยเหตุนี้ วงการฝึกอบรมด้านบนของเราจึงดำเนินการอย่างกระตือรือร้น

นี่เป็นสิ่งที่ดีสำหรับการดีบัก แต่การรวบรวมกราฟมีข้อได้เปรียบด้านประสิทธิภาพที่ชัดเจน การอธิบายการคำนวณของคุณเป็นกราฟคงที่ช่วยให้กรอบงานสามารถใช้การเพิ่มประสิทธิภาพการทำงานทั่วโลกได้ สิ่งนี้เป็นไปไม่ได้เมื่อกรอบงานถูกจำกัดให้ดำเนินการอย่างตะกละตะกลาม โดยไม่รู้ว่าอะไรจะเกิดขึ้นต่อไป

คุณสามารถคอมไพล์เป็นกราฟคงที่ฟังก์ชันใดๆ ก็ตามที่ใช้เมตริกซ์เป็นอินพุต เพียงเพิ่ม @tf.function มัณฑนากรในนั้นเช่นนี้

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

ลองทำเช่นเดียวกันกับขั้นตอนการประเมิน:

@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)

ตอนนี้ ให้เรียกใช้ลูปการฝึกของเราอีกครั้งด้วยขั้นตอนการฝึกที่รวบรวมไว้นี้:

import time

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))
Start of epoch 0
Training loss (for one batch) at step 0: 0.7921
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.7755
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.1564
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.3181
Seen so far: 38464 samples
Training acc over epoch: 0.8788
Validation acc: 0.8866
Time taken: 1.59s

Start of epoch 1
Training loss (for one batch) at step 0: 0.5222
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.4574
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.4035
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.7561
Seen so far: 38464 samples
Training acc over epoch: 0.8959
Validation acc: 0.9028
Time taken: 1.27s

เร็วกว่ามากใช่ไหม

การจัดการการสูญเสียในระดับต่ำที่ติดตามโดยแบบจำลอง

เลเยอร์และรูปแบบซ้ำติดตามความเสียหายใด ๆ ที่สร้างขึ้นในช่วงที่ผ่านมาโดยชั้นที่โทร self.add_loss(value) รายการที่เกิดจากการสูญเสียค่าสเกลาร์มีอยู่ผ่านทางทรัพย์สิน model.losses ในตอนท้ายของการส่งผ่านไปข้างหน้า

หากคุณต้องการใช้ส่วนประกอบการสูญเสียเหล่านี้ คุณควรรวมและบวกเข้ากับการสูญเสียหลักในขั้นตอนการฝึกของคุณ

พิจารณาเลเยอร์นี้ ที่สร้างการสูญเสียการทำให้เป็นมาตรฐานของกิจกรรม:

class ActivityRegularizationLayer(layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs

มาสร้างโมเดลง่ายๆ ที่ใช้มันกันเถอะ:

inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = layers.Dense(64, activation="relu")(x)
outputs = layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

นี่คือลักษณะขั้นตอนการฝึกอบรมของเราในตอนนี้:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
        # Add any extra losses created during the forward pass.
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

สรุป

ตอนนี้ คุณรู้ทุกอย่างที่ควรรู้เกี่ยวกับการใช้ลูปการฝึกในตัวและเขียนของคุณเองตั้งแต่เริ่มต้น

โดยสรุป ต่อไปนี้คือตัวอย่างง่ายๆ ตั้งแต่ต้นจนจบซึ่งเชื่อมโยงทุกสิ่งที่คุณได้เรียนรู้ในคู่มือนี้เข้าด้วยกัน: DCGAN ที่ได้รับการฝึกอบรมเกี่ยวกับตัวเลข MNIST

ตัวอย่างตั้งแต่ต้นจนจบ: วงการฝึก GAN ตั้งแต่เริ่มต้น

คุณอาจคุ้นเคยกับ Generative Adversarial Networks (GAN) GAN สามารถสร้างรูปภาพใหม่ที่ดูเหมือนจริง โดยการเรียนรู้การกระจายแฝงของชุดข้อมูลการฝึกอบรมของรูปภาพ ("พื้นที่แฝง" ของรูปภาพ)

GAN ประกอบด้วยสองส่วน: โมเดล "ตัวสร้าง" ที่จับคู่จุดในพื้นที่แฝงไปยังจุดในพื้นที่ภาพ แบบจำลอง "ผู้แยกแยะ" ตัวแยกประเภทที่สามารถบอกความแตกต่างระหว่างภาพจริง (จากชุดข้อมูลการฝึกอบรม) และของปลอม ภาพ (เอาต์พุตของเครือข่ายเครื่องกำเนิด)

วงการฝึก GAN มีลักษณะดังนี้:

1) ฝึกอบรมผู้เลือกปฏิบัติ - สุ่มตัวอย่างชุดของคะแนนสุ่มในพื้นที่แฝง - เปลี่ยนคะแนนเป็นภาพปลอมผ่านโมเดล "เครื่องกำเนิดไฟฟ้า" - รับชุดภาพจริงและรวมเข้ากับภาพที่สร้างขึ้น - ฝึกโมเดล "ผู้เลือกปฏิบัติ" เพื่อจัดประเภทภาพที่สร้างขึ้นเทียบกับภาพจริง

2) ฝึกเครื่องกำเนิดไฟฟ้า - ตัวอย่างจุดสุ่มในพื้นที่แฝง - เปลี่ยนคะแนนเป็นภาพปลอมผ่านเครือข่าย "เครื่องกำเนิดไฟฟ้า" - รับชุดภาพจริงและรวมเข้ากับภาพที่สร้างขึ้น - ฝึกโมเดล "เครื่องกำเนิดไฟฟ้า" เพื่อ "หลอก" ผู้แยกแยะและจำแนกภาพปลอมว่าเป็นของจริง

สำหรับภาพรวมรายละเอียดมากขึ้นของวิธี Gans ผลงานให้ดูที่ การเรียนรู้ลึกกับงูหลาม

ลองใช้วงจรการฝึกอบรมนี้ ขั้นแรก ให้สร้าง discriminator เพื่อจำแนกตัวเลขปลอมกับตัวเลขจริง:

discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
)
discriminator.summary()
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 14, 14, 64)        640       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 7, 7, 128)         73856     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 7, 7, 128)         0         
_________________________________________________________________
global_max_pooling2d (Global (None, 128)               0         
_________________________________________________________________
dense_4 (Dense)              (None, 1)                 129       
=================================================================
Total params: 74,625
Trainable params: 74,625
Non-trainable params: 0
_________________________________________________________________

จากนั้นเราจะมาสร้างเครือข่ายเครื่องกำเนิดไฟฟ้าที่จะเปลี่ยนพาหะแฝงเข้าไปในผลของรูปทรง (28, 28, 1) (แทนตัวเลข MNIST):

latent_dim = 128

generator = keras.Sequential(
    [
        keras.Input(shape=(latent_dim,)),
        # We want to generate 128 coefficients to reshape into a 7x7x128 map
        layers.Dense(7 * 7 * 128),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="generator",
)

นี่คือส่วนสำคัญ: วงการฝึก อย่างที่คุณเห็นมันค่อนข้างตรงไปตรงมา ฟังก์ชันขั้นตอนการฝึกใช้เพียง 17 บรรทัด

# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)

# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)


@tf.function
def train_step(real_images):
    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Decode them to fake images
    generated_images = generator(random_latent_vectors)
    # Combine them with real images
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Assemble labels discriminating real from fake images
    labels = tf.concat(
        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
    )
    # Add random noise to the labels - important trick!
    labels += 0.05 * tf.random.uniform(labels.shape)

    # Train the discriminator
    with tf.GradientTape() as tape:
        predictions = discriminator(combined_images)
        d_loss = loss_fn(labels, predictions)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))

    # Sample random points in the latent space
    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))
    # Assemble labels that say "all real images"
    misleading_labels = tf.zeros((batch_size, 1))

    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with tf.GradientTape() as tape:
        predictions = discriminator(generator(random_latent_vectors))
        g_loss = loss_fn(misleading_labels, predictions)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))
    return d_loss, g_loss, generated_images

ลองฝึก GAN ของเราด้วยซ้ำเรียก train_step ในกระบวนการของภาพ

เนื่องจากตัวแบ่งแยกและตัวสร้างของเราเป็นคอนเน็ตต์ คุณจะต้องเรียกใช้โค้ดนี้บน GPU

import os

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)

epochs = 1  # In practice you need at least 20 epochs to generate nice digits.
save_dir = "./"

for epoch in range(epochs):
    print("\nStart epoch", epoch)

    for step, real_images in enumerate(dataset):
        # Train the discriminator & generator on one batch of real images.
        d_loss, g_loss, generated_images = train_step(real_images)

        # Logging.
        if step % 200 == 0:
            # Print metrics
            print("discriminator loss at step %d: %.2f" % (step, d_loss))
            print("adversarial loss at step %d: %.2f" % (step, g_loss))

            # Save one generated image
            img = tf.keras.preprocessing.image.array_to_img(
                generated_images[0] * 255.0, scale=False
            )
            img.save(os.path.join(save_dir, "generated_img" + str(step) + ".png"))

        # To limit execution time we stop after 10 steps.
        # Remove the lines below to actually train the model!
        if step > 10:
            break
Start epoch 0
discriminator loss at step 0: 0.69
adversarial loss at step 0: 0.69

แค่นั้นแหละ! คุณจะได้รับตัวเลข MNIST ปลอมที่ดูดีหลังจากฝึกฝน Colab GPU เพียงประมาณ 30 วินาที