Giúp bảo vệ Great Barrier Reef với TensorFlow trên Kaggle Tham Challenge

Dòng chảy hiệu quả 2

Xem trên TensorFlow.org Chạy trong Google Colab Xem trên GitHub Tải xuống sổ ghi chép

Tổng quat

Hướng dẫn này cung cấp danh sách các phương pháp hay nhất để viết mã bằng TensorFlow 2 (TF2), nó được viết cho những người dùng gần đây đã chuyển sang từ TensorFlow 1 (TF1). Tham khảo các phần di cư của hướng dẫn để biết thêm về di cư đang TF1 của bạn để TF2.

Cài đặt

Nhập TensorFlow và các phần phụ thuộc khác cho các ví dụ trong hướng dẫn này.

import tensorflow as tf
import tensorflow_datasets as tfds

Đề xuất cho TensorFlow 2 thành ngữ

Cấu trúc lại mã của bạn thành các mô-đun nhỏ hơn

Một thực tiễn tốt là cấu trúc lại mã của bạn thành các hàm nhỏ hơn được gọi khi cần thiết. Để đạt hiệu quả tốt nhất, bạn nên cố gắng để trang trí các khối lớn nhất của tính toán rằng bạn có thể trong một tf.function (lưu ý rằng các chức năng python lồng nhau được gọi bằng một tf.function không yêu cầu trang trí riêng của họ, trừ khi bạn muốn sử dụng khác nhau jit_compile cài đặt cho các tf.function ). Tùy thuộc vào trường hợp sử dụng của bạn, đây có thể là nhiều bước đào tạo hoặc thậm chí là toàn bộ vòng đào tạo của bạn. Đối với các trường hợp sử dụng suy luận, nó có thể là một chuyển tiếp mô hình duy nhất.

Điều chỉnh tỷ lệ học mặc định cho một số tf.keras.optimizer s

Một số trình tối ưu hóa Keras có tỷ lệ học tập khác nhau trong TF2. Nếu bạn thấy thay đổi về hành vi hội tụ cho các mô hình của mình, hãy kiểm tra tốc độ học mặc định.

Không có thay đổi cho optimizers.SGD , optimizers.Adam , hoặc optimizers.RMSprop .

Các tỷ lệ học tập mặc định sau đây đã thay đổi:

Sử dụng tf.Module s và các lớp Keras để quản lý các biến

tf.Module s và tf.keras.layers.Layer s phục vụ thuận tiện variablestrainable_variables tài sản, trong đó đệ quy thu thập tất cả các biến phụ thuộc. Điều này giúp dễ dàng quản lý các biến cục bộ tại nơi chúng đang được sử dụng.

Keras lớp / mô hình kế thừa từ tf.train.Checkpointable và được tích hợp với @tf.function , mà làm cho nó có thể trực tiếp hoặc trạm kiểm soát SavedModels xuất khẩu từ các đối tượng Keras. Bạn không nhất thiết phải sử dụng Keras' Model.fit API để tận dụng lợi thế của các tích hợp này.

Đọc phần trên học tập chuyển giao và tinh chỉnh trong hướng dẫn sử Keras để học cách thu thập một tập hợp con của các biến có liên quan sử dụng Keras.

Kết hợp tf.data.Dataset s và tf.function

Các TensorFlow Datasets gói ( tfds ) chứa các tiện ích để tải các tập dữ liệu được xác định trước như tf.data.Dataset đối tượng. Trong ví dụ này, bạn có thể tải các tập dữ liệu MNIST sử dụng tfds :

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

Sau đó, chuẩn bị dữ liệu để đào tạo:

  • Thay đổi tỷ lệ từng hình ảnh.
  • Xáo trộn thứ tự của các ví dụ.
  • Thu thập hàng loạt hình ảnh và nhãn.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Để giữ cho ví dụ ngắn gọn, hãy cắt bớt tập dữ liệu để chỉ trả về 5 lô:

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-12-08 17:15:01.637157: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Sử dụng lặp lại Python thông thường để lặp qua dữ liệu đào tạo phù hợp với bộ nhớ. Nếu không, tf.data.Dataset là cách tốt nhất để truyền dữ liệu huấn luyện từ đĩa. Datasets là iterables (không lặp) , và làm việc giống như iterables Python khác trong thực hiện háo hức. Bạn hoàn toàn có thể sử dụng dữ liệu async tìm nạp trước / trực tuyến tính năng của gói mã của bạn trong tf.function , thay thế Python lặp với các hoạt động biểu đồ tương đương sử dụng chữ ký.

@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      loss = loss_fn(prediction, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Nếu bạn sử dụng Keras Model.fit API, bạn sẽ không phải lo lắng về dữ liệu lặp.

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)

Sử dụng các vòng huấn luyện Keras

Nếu bạn không cần kiểm soát ở mức độ thấp của quá trình đào tạo của bạn, sử dụng Keras' built-in fit , evaluatepredict phương pháp được khuyến khích. Các phương pháp này cung cấp một giao diện thống nhất để huấn luyện mô hình bất kể việc triển khai (tuần tự, chức năng hay phân lớp).

Ưu điểm của các phương pháp này bao gồm:

  • Họ chấp nhận mảng NumPy, máy phát điện và Python, tf.data.Datasets .
  • Họ áp dụng quy định và các khoản lỗ kích hoạt tự động.
  • Họ hỗ trợ tf.distribute nơi mã đào tạo vẫn giữ nguyên không phụ thuộc vào cấu hình phần cứng .
  • Họ hỗ trợ các khoản có thể gọi tùy ý dưới dạng lỗ và số liệu.
  • Họ hỗ trợ callbacks như tf.keras.callbacks.TensorBoard , và callbacks tùy chỉnh.
  • Chúng hoạt động hiệu quả, tự động sử dụng đồ thị TensorFlow.

Dưới đây là một ví dụ về đào tạo một mô hình sử dụng một Dataset . Để biết chi tiết về cách làm việc này, hãy kiểm tra hướng dẫn .

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 9s 7ms/step - loss: 1.5762 - accuracy: 0.4938
Epoch 2/5
2021-12-08 17:15:11.145429: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 6ms/step - loss: 0.5087 - accuracy: 0.8969
Epoch 3/5
2021-12-08 17:15:11.559374: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 2s 5ms/step - loss: 0.3348 - accuracy: 0.9469
Epoch 4/5
2021-12-08 17:15:13.860407: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.2445 - accuracy: 0.9688
Epoch 5/5
2021-12-08 17:15:14.269850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 6ms/step - loss: 0.2006 - accuracy: 0.9719
2021-12-08 17:15:14.717552: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 1s 4ms/step - loss: 1.4553 - accuracy: 0.5781
Loss 1.4552843570709229, Accuracy 0.578125
2021-12-08 17:15:15.862684: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Tùy chỉnh đào tạo và viết vòng lặp của riêng bạn

Nếu các mô hình Keras phù hợp với bạn, nhưng bạn cần sự linh hoạt và kiểm soát tốt hơn đối với bước đào tạo hoặc các vòng đào tạo bên ngoài, bạn có thể thực hiện các bước đào tạo của riêng mình hoặc thậm chí toàn bộ các vòng đào tạo. Xem hướng dẫn Keras trên tùy biến fit để tìm hiểu thêm.

Bạn cũng có thể thực hiện nhiều việc như một tf.keras.callbacks.Callback .

Phương pháp này có nhiều ưu điểm đã đề cập trước , nhưng cung cấp cho bạn kiểm soát các bước đào tạo và thậm chí cả vòng ngoài.

Có ba bước cho một vòng đào tạo tiêu chuẩn:

  1. Lặp trên một máy phát điện Python hoặc tf.data.Dataset để có được lô ví dụ.
  2. Sử dụng tf.GradientTape để gradient thu thập.
  3. Sử dụng một trong những tf.keras.optimizers để áp dụng bản cập nhật trọng lượng cho các biến của mô hình.

Nhớ lại:

  • Luôn luôn bao gồm một training luận trên call phương pháp của lớp subclassed và các mô hình.
  • Hãy chắc chắn để gọi mô hình với training lý luận đặt một cách chính xác.
  • Tùy thuộc vào cách sử dụng, các biến mô hình có thể không tồn tại cho đến khi mô hình được chạy trên một lô dữ liệu.
  • Bạn cần phải xử lý thủ công những thứ như tổn thất quy định cho mô hình.

Không cần phải chạy bộ khởi tạo biến hoặc thêm phần phụ thuộc điều khiển thủ công. tf.function xử lý phụ thuộc điều khiển tự động và khởi tạo biến trên tạo cho bạn.

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
2021-12-08 17:15:16.714849: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 0
2021-12-08 17:15:17.097043: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 1
2021-12-08 17:15:17.502480: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 2
2021-12-08 17:15:17.873701: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 3
Finished epoch 4
2021-12-08 17:15:18.344196: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Tận dụng lợi thế của tf.function với dòng điều khiển Python

tf.function cung cấp một cách để chuyển đổi dòng điều khiển dữ liệu phụ thuộc vào các khoản tương đương đồ thị dưới chế độ như tf.condtf.while_loop .

Một nơi phổ biến mà luồng điều khiển phụ thuộc vào dữ liệu xuất hiện là trong các mô hình tuần tự. tf.keras.layers.RNN kết thúc tốt đẹp một tế bào RNN, cho phép bạn hoặc là tĩnh hoặc động gỡ hình tái phát. Ví dụ: bạn có thể thực hiện lại động hủy cuộn như sau.

class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super(DynamicRNN, self).__init__(self)
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):

    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps =  tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)

my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)

Đọc tf.function hướng dẫn cho một biết thêm thông tin.

Các chỉ số và tổn thất kiểu mới

Metrics và lỗ đều các đối tượng công việc mà hăm hở và trong tf.function s.

Một đối tượng mất là callable, và hy vọng ( y_true , y_pred ) như các đối số:

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

Sử dụng các chỉ số để thu thập và hiển thị dữ liệu

Bạn có thể sử dụng tf.metrics để tổng hợp dữ liệu và tf.summary để đăng tóm tắt và chuyển hướng nó vào một nhà văn sử dụng một người quản lý ngữ cảnh. Các tóm tắt được phát ra trực tiếp đến nhà văn có nghĩa là bạn phải cung cấp các step giá trị tại callsite.

summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)

Sử dụng tf.metrics để tổng hợp dữ liệu trước khi đăng nhập chúng như tóm tắt. Các chỉ số là trạng thái; họ tích lũy giá trị và trả về một kết quả tích lũy khi bạn gọi result phương pháp (chẳng hạn như Mean.result ). Rõ ràng giá trị với tích lũy Model.reset_states .

def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)
    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()

def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(model(test_x, training=False), test_y)
  tf.summary.scalar('loss', loss, step=step_num)

train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')

with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)

Trực quan hóa các bản tóm tắt đã tạo bằng cách trỏ TensorBoard đến thư mục nhật ký tóm tắt:

tensorboard --logdir /tmp/summaries

Sử dụng các tf.summary API để dữ liệu tóm tắt ghi để hiển thị trong TensorBoard. Mọi chi tiết, đọc tf.summary dẫn .

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
2021-12-08 17:15:19.339736: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  0
  loss:     0.142
  accuracy: 0.991
2021-12-08 17:15:19.781743: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  1
  loss:     0.125
  accuracy: 0.997
2021-12-08 17:15:20.219033: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  2
  loss:     0.110
  accuracy: 0.997
2021-12-08 17:15:20.598085: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  3
  loss:     0.099
  accuracy: 0.997
Epoch:  4
  loss:     0.085
  accuracy: 1.000
2021-12-08 17:15:20.981787: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Tên chỉ số Keras

Các mô hình Keras nhất quán về việc xử lý các tên chỉ số. Khi bạn vượt qua một chuỗi trong danh sách các số liệu, mà chuỗi chính xác được sử dụng như của metric name . Những tên có thể nhìn thấy trong đối tượng lịch sử được trả về bởi model.fit , và trong các bản ghi thông qua với keras.callbacks . được đặt thành chuỗi bạn đã chuyển trong danh sách chỉ số.

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0963 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969
2021-12-08 17:15:21.942940: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

Gỡ lỗi

Sử dụng thực thi háo hức để chạy mã của bạn từng bước để kiểm tra hình dạng, kiểu dữ liệu và giá trị. Một số API, như tf.function , tf.keras , vv được thiết kế để sử dụng Graph thực hiện, cho hiệu suất và tính di động. Khi gỡ lỗi, sử dụng tf.config.run_functions_eagerly(True) để sử dụng thực hiện háo hức bên trong mã này.

Ví dụ:

@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x

tf.config.run_functions_eagerly(True)
f(tf.constant(1))
f()
-> x = x + 1
(Pdb) l
  6     @tf.function
  7     def f(x):
  8       if x > 0:
  9         import pdb
 10         pdb.set_trace()
 11  ->     x = x + 1
 12       return x
 13
 14     tf.config.run_functions_eagerly(True)
 15     f(tf.constant(1))
[EOF]

Điều này cũng hoạt động bên trong các mô hình Keras và các API khác hỗ trợ thực thi mong muốn:

class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2


tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
call()
-> return input_data // 2
(Pdb) l
 10         if tf.reduce_mean(input_data) > 0:
 11           return input_data
 12         else:
 13           import pdb
 14           pdb.set_trace()
 15  ->       return input_data // 2
 16
 17
 18     tf.config.run_functions_eagerly(True)
 19     model = CustomModel()
 20     model(tf.constant([-2, -4]))

Ghi chú:

Đừng giữ tf.Tensors trong đối tượng của bạn

Những đối tượng tensor có thể được tạo ra hoặc trong một tf.function hoặc trong bối cảnh háo hức, và những tensors hành xử khác nhau. Luôn luôn sử dụng tf.Tensor s chỉ cho các giá trị trung gian.

Để theo dõi trạng thái, sử dụng tf.Variable s như họ luôn luôn có thể sử dụng từ cả hai bối cảnh. Đọc tf.Variable hướng dẫn để tìm hiểu thêm.

Tài nguyên và đọc thêm

  • Đọc TF2 hướng dẫnhướng dẫn để tìm hiểu thêm về cách sử dụng TF2.

  • Nếu trước đây bạn đã sử dụng TF1.x, thì bạn nên chuyển mã của mình sang TF2. Đọc di cư hướng dẫn để tìm hiểu thêm.