Сохраните дату! Google I / O возвращается 18-20 мая Зарегистрируйтесь сейчас
Эта страница переведена с помощью Cloud Translation API.
Switch to English

Перенесите свой код TensorFlow 1 на TensorFlow 2

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

Это руководство предназначено для пользователей низкоуровневых API TensorFlow. Если вы используете высокоуровневые API-интерфейсы ( tf.keras ), вам может не потребоваться никаких действий, чтобы сделать ваш код полностью совместимым с TensorFlow 2.x:

По-прежнему можно запускать код 1.x без изменений ( кроме contrib ) в TensorFlow 2.x:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

Однако это не позволяет воспользоваться преимуществами многих улучшений, внесенных в TensorFlow 2.x. Это руководство поможет вам обновить код, сделав его более простым, производительным и легким в обслуживании.

Скрипт автоматической конвертации

Первый шаг, прежде чем пытаться реализовать изменения, описанные в этом руководстве, - это попытаться запустить сценарий обновления .

Это выполнит начальный проход при обновлении вашего кода до TensorFlow 2.x, но не сможет сделать ваш код идиоматическим до v2. Ваш код может по-прежнему использовать tf.compat.v1 точки tf.compat.v1 для доступа к заполнителям, сеансам, коллекциям и другим функциям в стиле 1.x.

Поведенческие изменения верхнего уровня

Если ваш код работает в TensorFlow 2.x с использованием tf.compat.v1.disable_v2_behavior , есть еще глобальные поведенческие изменения, которые вам, возможно, придется tf.compat.v1.disable_v2_behavior . Основные изменения:

  • Стремительное выполнение, v1.enable_eager_execution() : любой код, неявно использующий tf.Graph завершится ошибкой. Обязательно оберните этот код в with tf.Graph().as_default() .

  • Переменные ресурса, v1.enable_resource_variables() : некоторый код может зависеть от недетерминированного поведения, включенного ссылочными переменными TensorFlow. Переменные ресурса блокируются во время записи, что обеспечивает более интуитивно понятные гарантии согласованности.

    • Это может изменить поведение в крайних случаях.
    • Это может привести к созданию дополнительных копий и увеличению использования памяти.
    • Это можно отключить, передав use_resource=False конструктору tf.Variable .
  • Тензорные формы, v1.enable_v2_tensorshape() : TensorFlow 2.x упрощает поведение тензорных фигур. Вместо t.shape[0].value можно сказать t.shape[0] . Эти изменения должны быть небольшими, и есть смысл сразу их исправить. Примеры см. В разделе TensorShape .

  • Поток управления, v1.enable_control_flow_v2() : реализация потока управления в v1.enable_control_flow_v2() 2.x была упрощена, поэтому создаются различные представления графов. Пожалуйста, сообщайте об ошибках по любым вопросам.

Создайте код для TensorFlow 2.x

В этом руководстве будет рассмотрено несколько примеров преобразования кода TensorFlow 1.x в TensorFlow 2.x. Эти изменения позволят вашему коду воспользоваться преимуществами оптимизации производительности и упрощенных вызовов API.

В каждом случае шаблон следующий:

1. Заменить вызовы v1.Session.run

Каждый вызов v1.Session.run следует заменить функцией Python.

  • feed_dict и v1.placeholder становятся аргументами функции.
  • fetches становятся возвращаемым значением функции.
  • Во время преобразования активное выполнение позволяет легко отлаживать стандартные инструменты Python, такие как pdb .

После этого добавьте декоратор tf.function чтобы он эффективно работал на графике. Ознакомьтесь с руководством по автографу, чтобы узнать больше о том, как это работает.

Обратите внимание, что:

  • В отличие от v1.Session.run , tf.function имеет фиксированную подпись возврата и всегда возвращает все выходные данные. Если это вызывает проблемы с производительностью, создайте две отдельные функции.

  • Нет необходимости в tf.control_dependencies или аналогичных операциях: tf.function ведет себя так, как если бы она была запущена в указанном порядке. tf.Variable assignments и tf.assert s выполняются автоматически.

Раздел моделей преобразования содержит рабочий пример этого процесса преобразования.

2. Используйте объекты Python для отслеживания переменных и потерь.

В TensorFlow 2.x категорически не рекомендуется отслеживание переменных на основе имен. Используйте объекты Python для отслеживания переменных.

Используйте tf.Variable вместо v1.get_variable .

Каждый v1.variable_scope должен быть преобразован в объект Python. Обычно это будет одно из:

Если вам нужно объединить списки переменных (например, tf.Graph.get_collection(tf.GraphKeys.VARIABLES) ), используйте .variables и .trainable_variables объектов Layer и Model .

Эти классы Layer и Model реализуют несколько других свойств, которые устраняют необходимость в глобальных коллекциях. Их свойство .losses может заменить использование коллекции tf.GraphKeys.LOSSES .

Обратитесь к руководству по Keras для получения более подробной информации.

3. Обновите свои тренировочные циклы.

Используйте API самого высокого уровня, который подходит для вашего варианта использования. Предпочитайте tf.keras.Model.fit построению собственных тренировочных циклов.

Эти высокоуровневые функции управляют множеством низкоуровневых деталей, которые можно легко упустить, если вы напишете свой собственный цикл обучения. Например, они автоматически собирают потери регуляризации и устанавливают аргумент training=True при вызове модели.

4. Обновите конвейеры ввода данных.

Используйте tf.data данных tf.data для ввода данных. Эти объекты эффективны, выразительны и хорошо интегрируются с tenorflow.

Их можно передать непосредственно в метод tf.keras.Model.fit .

model.fit(dataset, epochs=5)

Их можно повторять напрямую через стандартный Python:

for example_batch, label_batch in dataset:
    break

5. compat.v1 символы compat.v1

Модуль tf.compat.v1 содержит полный API TensorFlow 1.x с его исходной семантикой.

Скрипт обновления TensorFlow 2.x преобразует символы в их эквиваленты v2, если такое преобразование является безопасным, т. v1.arg_max Если он может определить, что поведение версии TensorFlow 2.x точно эквивалентно (например, он переименует v1.arg_max в tf.argmax , поскольку это tf.argmax и та же функция).

После того, как сценарий обновления выполнен с помощью фрагмента кода, вероятно, будет много упоминаний о compat.v1 . Стоит просмотреть код и вручную преобразовать их в эквивалент v2 (если он есть, это следует указать в журнале).

Конвертация моделей

Низкоуровневые переменные и выполнение оператора

Примеры использования низкоуровневого API:

  • Использование переменных областей видимости для управления повторным использованием.
  • Создание переменных с помощью v1.get_variable .
  • Явный доступ к коллекциям.
  • Неявный доступ к коллекциям с помощью таких методов, как:

  • Использование v1.placeholder для настройки входных данных графа.

  • Выполнение графиков с помощью Session.run .

  • Инициализация переменных вручную.

Перед преобразованием

Вот как эти шаблоны могут выглядеть в коде с использованием TensorFlow 1.x.

import tensorflow as tf
import tensorflow.compat.v1 as v1

import tensorflow_datasets as tfds
g = v1.Graph()

with g.as_default():
  in_a = v1.placeholder(dtype=v1.float32, shape=(2))
  in_b = v1.placeholder(dtype=v1.float32, shape=(2))

  def forward(x):
    with v1.variable_scope("matmul", reuse=v1.AUTO_REUSE):
      W = v1.get_variable("W", initializer=v1.ones(shape=(2,2)),
                          regularizer=lambda x:tf.reduce_mean(x**2))
      b = v1.get_variable("b", initializer=v1.zeros(shape=(2)))
      return W * x + b

  out_a = forward(in_a)
  out_b = forward(in_b)
  reg_loss=v1.losses.get_regularization_loss(scope="matmul")

with v1.Session(graph=g) as sess:
  sess.run(v1.global_variables_initializer())
  outs = sess.run([out_a, out_b, reg_loss],
                feed_dict={in_a: [1, 0], in_b: [0, 1]})

print(outs[0])
print()
print(outs[1])
print()
print(outs[2])
[[1. 0.]
 [1. 0.]]

[[0. 1.]
 [0. 1.]]

1.0

После преобразования

В преобразованном коде:

  • Переменные - это локальные объекты Python.
  • Функция forward прежнему определяет расчет.
  • Вызов Session.run заменяется вызовом для forward .
  • Дополнительный декоратор tf.function может быть добавлен для повышения производительности.
  • Регуляризации рассчитываются вручную, без привязки к какой-либо глобальной коллекции.
  • Здесь не используются сеансы или заполнители .
W = tf.Variable(tf.ones(shape=(2,2)), name="W")
b = tf.Variable(tf.zeros(shape=(2)), name="b")

@tf.function
def forward(x):
  return W * x + b

out_a = forward([1,0])
print(out_a)
tf.Tensor(
[[1. 0.]
 [1. 0.]], shape=(2, 2), dtype=float32)
out_b = forward([0,1])

regularizer = tf.keras.regularizers.l2(0.04)
reg_loss=regularizer(W)

Модели на основе tf.layers

Модуль v1.layers используется для хранения функций уровня, которые полагаются на v1.variable_scope для определения и повторного использования переменных.

Перед преобразованием

def model(x, training, scope='model'):
  with v1.variable_scope(scope, reuse=v1.AUTO_REUSE):
    x = v1.layers.conv2d(x, 32, 3, activation=v1.nn.relu,
          kernel_regularizer=lambda x:0.004*tf.reduce_mean(x**2))
    x = v1.layers.max_pooling2d(x, (2, 2), 1)
    x = v1.layers.flatten(x)
    x = v1.layers.dropout(x, 0.1, training=training)
    x = v1.layers.dense(x, 64, activation=v1.nn.relu)
    x = v1.layers.batch_normalization(x, training=training)
    x = v1.layers.dense(x, 10)
    return x
train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)

print(train_out)
print()
print(test_out)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/convolutional.py:414: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  warnings.warn('`tf.layers.conv2d` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:2273: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.
  warnings.warn('`layer.apply` is deprecated and '
tf.Tensor([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(1, 10), dtype=float32)

tf.Tensor(
[[ 0.379358   -0.55901194  0.48704922  0.11619566  0.23902717  0.01691487
   0.07227738  0.14556988  0.2459927   0.2501198 ]], shape=(1, 10), dtype=float32)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/pooling.py:310: UserWarning: `tf.layers.max_pooling2d` is deprecated and will be removed in a future version. Please use `tf.keras.layers.MaxPooling2D` instead.
  warnings.warn('`tf.layers.max_pooling2d` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:329: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  warnings.warn('`tf.layers.flatten` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:268: UserWarning: `tf.layers.dropout` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dropout` instead.
  warnings.warn('`tf.layers.dropout` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:171: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  warnings.warn('`tf.layers.dense` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/normalization.py:308: UserWarning: `tf.layers.batch_normalization` is deprecated and will be removed in a future version. Please use `tf.keras.layers.BatchNormalization` instead. In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.BatchNormalization` documentation).
  '`tf.layers.batch_normalization` is deprecated and '

После преобразования

Большинство аргументов осталось прежним. Но обратите внимание на различия:

  • Аргумент training передается модели на каждый уровень при ее запуске.
  • Первый аргумент функции исходной model (вход x ) пропал. Это связано с тем, что слои объектов отделяют построение модели от вызова модели.

Также обратите внимание, что:

  • Если вы используете регуляризаторы или инициализаторы из tf.contrib , у них больше изменений аргументов, чем у других.
  • Код больше не записывается в коллекции, поэтому такие функции, как v1.losses.get_regularization_loss больше не будут возвращать эти значения, что может нарушить ваши циклы обучения.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.04),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))
train_out = model(train_data, training=True)
print(train_out)
tf.Tensor([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(1, 10), dtype=float32)
test_out = model(test_data, training=False)
print(test_out)
tf.Tensor(
[[-0.2145557  -0.22979769 -0.14968733  0.01208701 -0.07569927  0.3475932
   0.10718458  0.03482988 -0.04309493 -0.10469118]], shape=(1, 10), dtype=float32)
# Here are all the trainable variables
len(model.trainable_variables)
8
# Here is the regularization loss
model.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.08174552>]

Смешанные переменные и v1.layers

В существующем коде часто смешиваются переменные TensorFlow 1.x нижнего уровня и операции с уровнями v1.layers более высокого уровня.

Перед преобразованием

def model(x, training, scope='model'):
  with v1.variable_scope(scope, reuse=v1.AUTO_REUSE):
    W = v1.get_variable(
      "W", dtype=v1.float32,
      initializer=v1.ones(shape=x.shape),
      regularizer=lambda x:0.004*tf.reduce_mean(x**2),
      trainable=True)
    if training:
      x = x + W
    else:
      x = x + W * 0.5
    x = v1.layers.conv2d(x, 32, 3, activation=tf.nn.relu)
    x = v1.layers.max_pooling2d(x, (2, 2), 1)
    x = v1.layers.flatten(x)
    return x

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)

После преобразования

Чтобы преобразовать этот код, следуйте шаблону сопоставления слоев со слоями, как в предыдущем примере.

Общая схема такова:

  • Собираем параметры слоя в __init__ .
  • Соберите переменные в build .
  • Выполните вычисления в call и верните результат.

v1.variable_scope - это, по сути, отдельный слой. Так что tf.keras.layers.Layer его как tf.keras.layers.Layer . Ознакомьтесь с руководством « Создание новых слоев и моделей с помощью подклассов» .

# Create a custom layer for part of the model
class CustomLayer(tf.keras.layers.Layer):
  def __init__(self, *args, **kwargs):
    super(CustomLayer, self).__init__(*args, **kwargs)

  def build(self, input_shape):
    self.w = self.add_weight(
        shape=input_shape[1:],
        dtype=tf.float32,
        initializer=tf.keras.initializers.ones(),
        regularizer=tf.keras.regularizers.l2(0.02),
        trainable=True)

  # Call method will sometimes get used in graph mode,
  # training will get turned into a tensor
  @tf.function
  def call(self, inputs, training=None):
    if training:
      return inputs + self.w
    else:
      return inputs + self.w * 0.5
custom_layer = CustomLayer()
print(custom_layer([1]).numpy())
print(custom_layer([1], training=True).numpy())
[1.5]
[2.]
train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))

# Build the model including the custom layer
model = tf.keras.Sequential([
    CustomLayer(input_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
])

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)

Несколько замечаний:

  • Подклассы моделей и слоев Keras должны работать как в графах v1 (без зависимостей автоматического управления), так и в режиме ожидания:

    • Оберните call в tf.function чтобы получить автограф и зависимости автоматического управления.
  • Не забудьте принять аргумент training для call :

    • Иногда это tf.Tensor
    • Иногда это логическое значение Python
  • Создайте переменные модели в конструкторе или Model.build используя self.add_weight:

    • В Model.build вас есть доступ к входной форме, поэтому вы можете создавать веса с соответствующей формой
    • Использование tf.keras.layers.Layer.add_weight позволяет tf.keras.layers.Layer.add_weight отслеживать переменные и потери регуляризации.
  • Не храните tf.Tensors в своих объектах:

    • Они могут быть созданы либо в tf.function либо в контексте tf.function , и эти тензоры ведут себя по-разному.
    • Используйте tf.Variable для состояния, они всегда доступны в обоих контекстах.
    • tf.Tensors только для промежуточных значений

Заметка о Slim и contrib.layers

Большая часть старого кода TensorFlow 1.x использует библиотеку Slim , которая была упакована с TensorFlow 1.x как tf.contrib.layers . Как модуль contrib , он больше не доступен в TensorFlow 2.x, даже в tf.compat.v1 . Преобразование кода с использованием Slim в TensorFlow 2.x сложнее, чем преобразование репозиториев, использующих v1.layers . Фактически, может иметь смысл сначала преобразовать ваш Slim-код в v1.layers , а затем преобразовать в Keras.

  • Удалите arg_scopes , все аргументы должны быть явными.
  • Если вы их используете, разделите normalizer_fn и activation_fn на их собственные слои.
  • Разделимые сверточные слои сопоставляются с одним или несколькими различными слоями Keras (глубинными, точечными и разделяемыми слоями Keras).
  • Slim и v1.layers имеют разные имена аргументов и значения по умолчанию.
  • Некоторые аргументы имеют разный масштаб.
  • Если вы используете предварительно обученные модели Slim, попробуйте предварительно обученные модели tf.keras.applications из tf.keras.applications или TensorFlow 2.x SavedModels от TF Hub, экспортированные из исходного кода Slim.

Некоторые уровни tf.contrib могли не быть перемещены в ядро ​​TensorFlow, а вместо этого были перемещены в пакет tf.contrib .

Обучение

Есть много способов tf.keras данные в модель tf.keras . Они будут принимать генераторы Python и массивы Numpy в качестве входных данных.

Рекомендуемый способ передачи данных в модель - использовать пакет tf.data , который содержит набор высокопроизводительных классов для управления данными.

Если вы все еще используете tf.queue , теперь они поддерживаются только как структуры данных, а не как входные конвейеры.

Использование наборов данных TensorFlow

TensorFlow Datasets пакет ( tfds ) содержит утилиты для загрузки предопределенных наборов данных в качествеtf.data.Dataset объектов.

В этом примере вы можете загрузить набор данных MNIST с помощью tfds :

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1...
WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your
local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

Затем подготовьте данные для обучения:

  • Измените масштаб каждого изображения.
  • Перемешайте примеры.
  • Собирайте партии изображений и этикеток.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Чтобы пример был кратким, обрежьте набор данных, чтобы получить только 5 пакетов:

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))

Используйте обучающие циклы Keras

Если вам не нужен низкоуровневый контроль вашего тренировочного процесса, рекомендуется использовать встроенные в predict методы fit , evaluate и predict . Эти методы предоставляют единый интерфейс для обучения модели независимо от реализации (последовательной, функциональной или подклассовой).

К преимуществам этих методов можно отнести:

  • Они принимают массивы Numpy, генераторы Python и tf.data.Datasets .
  • Они автоматически применяют регуляризацию и потери активации.
  • Они поддерживают tf.distribute для tf.distribute на нескольких устройствах .
  • Они поддерживают произвольные вызовы в виде потерь и показателей.
  • Они поддерживают обратные вызовы, такие как tf.keras.callbacks.TensorBoard , и настраиваемые обратные вызовы.
  • Они производительны, автоматически используя графики TensorFlow.

Вот пример обучения модели с использованием Dataset . (Подробнее о том, как это работает, читайте в разделе руководств .)

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 1s 9ms/step - loss: 2.0191 - accuracy: 0.3608
Epoch 2/5
5/5 [==============================] - 0s 9ms/step - loss: 0.4736 - accuracy: 0.9059
Epoch 3/5
5/5 [==============================] - 0s 8ms/step - loss: 0.2973 - accuracy: 0.9626
Epoch 4/5
5/5 [==============================] - 0s 9ms/step - loss: 0.2108 - accuracy: 0.9911
Epoch 5/5
5/5 [==============================] - 0s 8ms/step - loss: 0.1791 - accuracy: 0.9874
5/5 [==============================] - 0s 6ms/step - loss: 1.5504 - accuracy: 0.7500
Loss 1.5504140853881836, Accuracy 0.75

Напишите свой собственный цикл

Если шаг обучения модели tf.keras.Model.train_on_batch работает для вас, но вам нужен больший контроль за пределами этого шага, рассмотрите возможность использования метода tf.keras.Model.train_on_batch в вашем собственном цикле итерации данных.

Помните: многие вещи можно реализовать как tf.keras.callbacks.Callback .

Этот метод обладает многими преимуществами методов, упомянутых в предыдущем разделе, но дает пользователю возможность управлять внешним циклом.

Вы также можете использовать tf.keras.Model.test_on_batch или tf.keras.Model.evaluate для проверки производительности во время обучения.

Чтобы продолжить обучение указанной выше модели:

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

for epoch in range(NUM_EPOCHS):
  # Reset the metric accumulators
  model.reset_metrics()

  for image_batch, label_batch in train_data:
    result = model.train_on_batch(image_batch, label_batch)
    metrics_names = model.metrics_names
    print("train: ",
          "{}: {:.3f}".format(metrics_names[0], result[0]),
          "{}: {:.3f}".format(metrics_names[1], result[1]))
  for image_batch, label_batch in test_data:
    result = model.test_on_batch(image_batch, label_batch,
                                 # Return accumulated metrics
                                 reset_metrics=False)
  metrics_names = model.metrics_names
  print("\neval: ",
        "{}: {:.3f}".format(metrics_names[0], result[0]),
        "{}: {:.3f}".format(metrics_names[1], result[1]))
train:  loss: 0.138 accuracy: 1.000
train:  loss: 0.161 accuracy: 1.000
train:  loss: 0.159 accuracy: 0.969
train:  loss: 0.241 accuracy: 0.953
train:  loss: 0.172 accuracy: 0.969

eval:  loss: 1.550 accuracy: 0.800
train:  loss: 0.086 accuracy: 1.000
train:  loss: 0.094 accuracy: 1.000
train:  loss: 0.090 accuracy: 1.000
train:  loss: 0.119 accuracy: 0.984
train:  loss: 0.099 accuracy: 1.000

eval:  loss: 1.558 accuracy: 0.841
train:  loss: 0.076 accuracy: 1.000
train:  loss: 0.068 accuracy: 1.000
train:  loss: 0.061 accuracy: 1.000
train:  loss: 0.076 accuracy: 1.000
train:  loss: 0.076 accuracy: 1.000

eval:  loss: 1.536 accuracy: 0.841
train:  loss: 0.059 accuracy: 1.000
train:  loss: 0.056 accuracy: 1.000
train:  loss: 0.058 accuracy: 1.000
train:  loss: 0.054 accuracy: 1.000
train:  loss: 0.055 accuracy: 1.000

eval:  loss: 1.497 accuracy: 0.863
train:  loss: 0.053 accuracy: 1.000
train:  loss: 0.049 accuracy: 1.000
train:  loss: 0.044 accuracy: 1.000
train:  loss: 0.049 accuracy: 1.000
train:  loss: 0.045 accuracy: 1.000

eval:  loss: 1.463 accuracy: 0.878

Настройте шаг обучения

Если вам нужна большая гибкость и контроль, вы можете получить это, реализовав свой собственный цикл обучения. Есть три шага:

  1. tf.data.Dataset генератор Python илиtf.data.Dataset чтобы получить партии примеров.
  2. Используйтеtf.GradientTape для сбора градиентов.
  3. Используйте один из tf.keras.optimizers чтобы применить обновления веса к переменным модели.

Помнить:

  • Всегда включайте аргумент training в метод call подклассов слоев и моделей.
  • Убедитесь, что вы вызываете модель с правильно заданным аргументом training .
  • В зависимости от использования переменные модели могут не существовать, пока модель не будет запущена на пакете данных.
  • Вам нужно вручную обрабатывать такие вещи, как потери регуляризации для модели.

Обратите внимание на упрощения относительно v1:

  • Нет необходимости запускать инициализаторы переменных. Переменные инициализируются при создании.
  • Нет необходимости добавлять зависимости ручного управления. Даже в tf.function операции действуют как в режиме ожидания.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
Finished epoch 0
Finished epoch 1
Finished epoch 2
Finished epoch 3
Finished epoch 4

Метрики и потери в новом стиле

В TensorFlow 2.x показатели и потери являются объектами. Они работают как с готовностью, так и с tf.function .

Объект потерь является вызываемым и ожидает (y_true, y_pred) в качестве аргументов:

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

Метрический объект имеет следующие методы:

  • Metric.update_state() : добавить новые наблюдения.
  • Metric.result() : получить текущий результат метрики с учетом наблюдаемых значений.
  • Metric.reset_states() : очистить все наблюдения.

Сам объект является вызываемым. Вызов обновляет состояние с новыми наблюдениями, как в случае с update_state , и возвращает новый результат метрики.

Вам не нужно вручную инициализировать переменные метрики, и поскольку TensorFlow 2.x имеет зависимости автоматического управления, вам также не нужно беспокоиться о них.

В приведенном ниже коде используется метрика для отслеживания средней потери, наблюдаемой в пользовательском цикле обучения.

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
Epoch:  0
  loss:     0.139
  accuracy: 0.997
Epoch:  1
  loss:     0.116
  accuracy: 1.000
Epoch:  2
  loss:     0.105
  accuracy: 0.997
Epoch:  3
  loss:     0.089
  accuracy: 1.000
Epoch:  4
  loss:     0.078
  accuracy: 1.000

Названия метрик Keras

В TensorFlow 2.x модели Keras более последовательны в обработке имен метрик.

Теперь , когда вы передаете строку в списке показателей, что точная строка используется в качестве метрики по name . Эти имена видны в объекте истории, возвращаемом model.fit , и в журналах, переданных в keras.callbacks . устанавливается в строку, которую вы передали в списке показателей.

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 8ms/step - loss: 0.0901 - acc: 0.9923 - accuracy: 0.9923 - my_accuracy: 0.9923
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

Это отличается от предыдущих версий, в которых передача metrics=["accuracy"] dict_keys(['loss', 'acc']) metrics=["accuracy"] приводила к dict_keys(['loss', 'acc'])

Оптимизаторы Keras

Оптимизаторы в v1.train , такие как v1.train.AdamOptimizer и v1.train.GradientDescentOptimizer , имеют эквиваленты в tf.keras.optimizers .

Преобразование v1.train в keras.optimizers

При преобразовании оптимизаторов следует помнить следующее:

Новые настройки по умолчанию для некоторых tf.keras.optimizers

Нет изменений для optimizers.SGD , optimizers.Adam или optimizers.RMSprop .

Изменились следующие скорости обучения по умолчанию:

TensorBoard

TensorFlow 2.x включает значительные изменения в API tf.summary используется для записи сводных данных для визуализации в TensorBoard. Для общего введения в новый tf.summary доступно несколько руководств , в которых используется TensorFlow 2.x API. Сюда входит руководство по миграции TensorBoard TensorFlow 2.x.

Сохранение и загрузка

Совместимость контрольных точек

TensorFlow 2.x использует объектно-ориентированные контрольные точки .

Контрольные точки на основе имен в старом стиле все еще могут быть загружены, если вы будете осторожны. Процесс преобразования кода может привести к изменению имени переменной, но есть обходные пути.

Самый простой подход - совместить названия новой модели с названиями в контрольной точке:

  • Все переменные по-прежнему имеют аргумент name вы можете установить.
  • Модели Keras также принимают аргумент name который они устанавливают в качестве префикса для своих переменных.
  • v1.name_scope можно использовать для установки префиксов имен переменных. Это сильно отличается от tf.variable_scope . Он влияет только на имена и не отслеживает переменные и повторное использование.

Если в вашем случае это не работает, попробуйте функцию v1.train.init_from_checkpoint . Он принимает аргумент assignment_map , который определяет отображение старых имен на новые.

Репозиторий TensorFlow Estimator включает инструмент преобразования для обновления контрольных точек для предварительно созданных оценщиков с TensorFlow 1.x до 2.0. Это может служить примером того, как создать инструмент для аналогичного варианта использования.

Совместимость сохраненных моделей

Для сохраненных моделей нет серьезных проблем с совместимостью.

  • Сохраненные_модели TensorFlow 1.x работают в TensorFlow 2.x.
  • Сохраненные_модели TensorFlow 2.x работают в TensorFlow 1.x, если поддерживаются все операции.

Graph.pb или Graph.pbtxt

Нет простого способа обновить необработанный файл Graph.pb до TensorFlow 2.x. Лучше всего обновить код, создавший файл.

Но если у вас есть «замороженный график» ( tf.Graph где переменные преобразованы в константы), то можно преобразовать его в concrete_function v1.wrap_function с помощью v1.wrap_function :

def wrap_frozen_graph(graph_def, inputs, outputs):
  def _imports_graph_def():
    tf.compat.v1.import_graph_def(graph_def, name="")
  wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      tf.nest.map_structure(import_graph.as_graph_element, inputs),
      tf.nest.map_structure(import_graph.as_graph_element, outputs))

Например, вот замороженный график для Inception v1 от 2016 года:

path = tf.keras.utils.get_file(
    'inception_v1_2016_08_28_frozen.pb',
    'http://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz',
    untar=True)
Downloading data from http://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz
24698880/24695710 [==============================] - 1s 0us/step

Загрузите tf.GraphDef :

graph_def = tf.compat.v1.GraphDef()
loaded = graph_def.ParseFromString(open(path,'rb').read())

Оберните это в concrete_function :

inception_func = wrap_frozen_graph(
    graph_def, inputs='input:0',
    outputs='InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu:0')

Передайте ему тензор в качестве входных данных:

input_img = tf.ones([1,224,224,3], dtype=tf.float32)
inception_func(input_img).shape
TensorShape([1, 28, 28, 96])

Оценщики

Обучение с оценщиками

Оценщики поддерживаются в TensorFlow 2.x.

Когда вы используете оценщики, вы можете использовать input_fn , tf.estimator.TrainSpec и tf.estimator.EvalSpec из tf.estimator.EvalSpec 1.x.

Вот пример использования input_fn с обучением и оценкой спецификаций.

Создание спецификаций input_fn и train / eval

# Define the estimator's input_fn
def input_fn():
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000
  BATCH_SIZE = 64

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label[..., tf.newaxis]

  train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  return train_data.repeat()

# Define train and eval specs
train_spec = tf.estimator.TrainSpec(input_fn=input_fn,
                                    max_steps=STEPS_PER_EPOCH * NUM_EPOCHS)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn,
                                  steps=STEPS_PER_EPOCH)

Использование определения модели Keras

Есть некоторые отличия в том, как создавать ваши оценщики в TensorFlow 2.x.

Рекомендуется определить модель с помощью Keras, а затем использовать утилиту tf.keras.estimator.model_to_estimator чтобы превратить вашу модель в оценщик. В приведенном ниже коде показано, как использовать эту утилиту при создании и обучении оценщика.

def make_model():
  return tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
  ])
model = make_model()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

estimator = tf.keras.estimator.model_to_estimator(
  keras_model = model
)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp0erq3im2
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp0erq3im2
INFO:tensorflow:Using the Keras model provided.
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/backend.py:434: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp0erq3im2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp0erq3im2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp0erq3im2/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp0erq3im2/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmp0erq3im2/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmp/tmp0erq3im2/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 8 variables.
INFO:tensorflow:Warm-started 8 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp0erq3im2/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp0erq3im2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.4717796, step = 0
INFO:tensorflow:loss = 2.4717796, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmp0erq3im2/model.ckpt.
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmp0erq3im2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2325: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:17Z
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:17Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp0erq3im2/model.ckpt-25
INFO:tensorflow:Restoring parameters from /tmp/tmp0erq3im2/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.86556s
INFO:tensorflow:Inference Time : 0.86556s
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:18
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:18
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.6, global_step = 25, loss = 1.6160676
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.6, global_step = 25, loss = 1.6160676
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmp0erq3im2/model.ckpt-25
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmp0erq3im2/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.37597787.
INFO:tensorflow:Loss for final step: 0.37597787.
({'accuracy': 0.6, 'loss': 1.6160676, 'global_step': 25}, [])

Использование пользовательской model_fn

Если у вас есть существующий пользовательский оценщик model_fn который вам необходимо поддерживать, вы можете преобразовать свой model_fn для использования модели model_fn .

Однако из соображений совместимости пользовательский model_fn прежнему будет работать в графическом режиме в стиле 1.x. Это означает, что нет активного исполнения и зависимостей автоматического управления.

Пользовательский model_fn с минимальными изменениями

Чтобы ваша настраиваемая model_fn работала в TensorFlow 2.x, если вы предпочитаете минимальные изменения в существующем коде, tf.compat.v1 символы tf.compat.v1 такие как optimizers и metrics .

Использование модели model_fn в пользовательской model_fn аналогично ее использованию в пользовательском цикле обучения:

  • Установите подходящую фазу training зависимости от аргумента mode .
  • Явно передайте оптимизатору trainable_variables модели.

Но есть важные отличия относительно пользовательского цикла :

  • Вместо использования Model.losses извлеките убытки с помощью Model.get_losses_for .
  • Извлеките обновления модели с помощью Model.get_updates_for .

Следующий код создает оценщик из пользовательской model_fn , иллюстрируя все эти проблемы.

def my_model_fn(features, labels, mode):
  model = make_model()

  optimizer = tf.compat.v1.train.AdamOptimizer()
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  training = (mode == tf.estimator.ModeKeys.TRAIN)
  predictions = model(features, training=training)

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  total_loss=loss_fn(labels, predictions) + tf.math.add_n(reg_losses)

  accuracy = tf.compat.v1.metrics.accuracy(labels=labels,
                                           predictions=tf.math.argmax(predictions, axis=1),
                                           name='acc_op')

  update_ops = model.get_updates_for(None) + model.get_updates_for(features)
  minimize_op = optimizer.minimize(
      total_loss,
      var_list=model.trainable_variables,
      global_step=tf.compat.v1.train.get_or_create_global_step())
  train_op = tf.group(minimize_op, update_ops)

  return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=total_loss,
    train_op=train_op, eval_metric_ops={'accuracy': accuracy})

# Create the Estimator & Train
estimator = tf.estimator.Estimator(model_fn=my_model_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpifj8mysl
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpifj8mysl
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpifj8mysl', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpifj8mysl', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpifj8mysl/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpifj8mysl/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 3.0136237, step = 0
INFO:tensorflow:loss = 3.0136237, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmpifj8mysl/model.ckpt.
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmpifj8mysl/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:20Z
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:20Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpifj8mysl/model.ckpt-25
INFO:tensorflow:Restoring parameters from /tmp/tmpifj8mysl/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.97406s
INFO:tensorflow:Inference Time : 0.97406s
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:21
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:21
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.59375, global_step = 25, loss = 1.6248872
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.59375, global_step = 25, loss = 1.6248872
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmpifj8mysl/model.ckpt-25
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmpifj8mysl/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.35726172.
INFO:tensorflow:Loss for final step: 0.35726172.
({'accuracy': 0.59375, 'loss': 1.6248872, 'global_step': 25}, [])

Пользовательский model_fn с символами model_fn 2.x

Если вы хотите избавиться от всех символов model_fn 1.x и обновить свой собственный model_fn до model_fn 2.x, вам необходимо обновить оптимизатор и метрики до tf.keras.optimizers и tf.keras.metrics .

В пользовательской model_fn , помимо указанных выше изменений , необходимо выполнить дополнительные обновления:

  • Используйте tf.keras.optimizers вместо v1.train.Optimizer .
  • Явно передайте trainable_variables модели в tf.keras.optimizers .
  • Чтобы вычислить train_op/minimize_op ,
    • Используйте Optimizer.get_updates если потеря - это скалярный Tensor потерь (не вызываемый). Первый элемент в возвращаемом списке - это желаемый train_op/minimize_op .
    • Если потеря вызывается (например, функцией), используйте Optimizer.minimize чтобы получить train_op/minimize_op .
  • Для оценки используйте tf.keras.metrics вместо tf.compat.v1.metrics .

В приведенном выше примере my_model_fn перенесенный код с символами TensorFlow 2.x показан как:

def my_model_fn(features, labels, mode):
  model = make_model()

  training = (mode == tf.estimator.ModeKeys.TRAIN)
  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  predictions = model(features, training=training)

  # Get both the unconditional losses (the None part)
  # and the input-conditional losses (the features part).
  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  total_loss=loss_obj(labels, predictions) + tf.math.add_n(reg_losses)

  # Upgrade to tf.keras.metrics.
  accuracy_obj = tf.keras.metrics.Accuracy(name='acc_obj')
  accuracy = accuracy_obj.update_state(
      y_true=labels, y_pred=tf.math.argmax(predictions, axis=1))

  train_op = None
  if training:
    # Upgrade to tf.keras.optimizers.
    optimizer = tf.keras.optimizers.Adam()
    # Manually assign tf.compat.v1.global_step variable to optimizer.iterations
    # to make tf.compat.v1.train.global_step increased correctly.
    # This assignment is a must for any `tf.train.SessionRunHook` specified in
    # estimator, as SessionRunHooks rely on global step.
    optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()
    # Get both the unconditional updates (the None part)
    # and the input-conditional updates (the features part).
    update_ops = model.get_updates_for(None) + model.get_updates_for(features)
    # Compute the minimize_op.
    minimize_op = optimizer.get_updates(
        total_loss,
        model.trainable_variables)[0]
    train_op = tf.group(minimize_op, *update_ops)

  return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=total_loss,
    train_op=train_op,
    eval_metric_ops={'Accuracy': accuracy_obj})

# Create the Estimator and train.
estimator = tf.estimator.Estimator(model_fn=my_model_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpc93qfnv6
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpc93qfnv6
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpc93qfnv6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpc93qfnv6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpc93qfnv6/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpc93qfnv6/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.5293791, step = 0
INFO:tensorflow:loss = 2.5293791, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmpc93qfnv6/model.ckpt.
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmpc93qfnv6/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:24Z
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:24Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpc93qfnv6/model.ckpt-25
INFO:tensorflow:Restoring parameters from /tmp/tmpc93qfnv6/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.86534s
INFO:tensorflow:Inference Time : 0.86534s
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:25
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:25
INFO:tensorflow:Saving dict for global step 25: Accuracy = 0.59375, global_step = 25, loss = 1.7570661
INFO:tensorflow:Saving dict for global step 25: Accuracy = 0.59375, global_step = 25, loss = 1.7570661
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmpc93qfnv6/model.ckpt-25
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmpc93qfnv6/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.47094986.
INFO:tensorflow:Loss for final step: 0.47094986.
({'Accuracy': 0.59375, 'loss': 1.7570661, 'global_step': 25}, [])

Готовые оценщики

Готовые оценщики из семейства tf.estimator.DNN* , tf.estimator.Linear* и tf.estimator.DNNLinearCombined* по-прежнему поддерживаются в TensorFlow 2.x API. Однако некоторые аргументы изменились:

  1. input_layer_partitioner : удалено в v2.
  2. loss_reduction : обновлено до tf.keras.losses.Reduction вместо tf.compat.v1.losses.Reduction . Его значение по умолчанию также изменяется на tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE из tf.compat.v1.losses.Reduction.SUM .
  3. optimizer , dnn_optimizer и linear_optimizer : этот аргумент был обновлен до tf.keras.optimizers вместо tf.compat.v1.train.Optimizer .

Чтобы перенести указанные выше изменения:

  1. Для input_layer_partitioner миграция не требуется, поскольку Distribution Strategy будет обрабатывать ее автоматически в TensorFlow 2.x.
  2. Для loss_reduction проверьте tf.keras.losses.Reduction чтобы tf.keras.losses.Reduction о поддерживаемых параметрах.
  3. Для аргументов optimizer :
    • Если вы не: 1) передаете аргумент optimizer , dnn_optimizer или linear_optimizer , или 2) указываете аргумент optimizer как string в вашем коде, вам не нужно ничего менять, потому что tf.keras.optimizers используется по умолчанию .
    • В противном случае вам необходимо обновить его с tf.compat.v1.train.Optimizer до соответствующего tf.keras.optimizers .

Конвертер контрольных точек

Переход на keras.optimizers нарушит контрольные точки, сохраненные с помощью TensorFlow 1.x, поскольку tf.keras.optimizers генерирует другой набор переменных, которые будут сохранены в контрольных точках. Чтобы сделать старую контрольную точку многоразовой после перехода на TensorFlow 2.x, попробуйте инструмент конвертера контрольных точек .

 curl -O https://raw.githubusercontent.com/tensorflow/estimator/master/tensorflow_estimator/python/estimator/tools/checkpoint_converter.py
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 15165  100 15165    0     0  40656      0 --:--:-- --:--:-- --:--:-- 40656

Инструмент имеет встроенную справку:

 python checkpoint_converter.py -h
2021-01-06 02:31:26.297951: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
usage: checkpoint_converter.py [-h]
                               {dnn,linear,combined} source_checkpoint
                               source_graph target_checkpoint

positional arguments:
  {dnn,linear,combined}
                        The type of estimator to be converted. So far, the
                        checkpoint converter only supports Canned Estimator.
                        So the allowed types include linear, dnn and combined.
  source_checkpoint     Path to source checkpoint file to be read in.
  source_graph          Path to source graph file to be read in.
  target_checkpoint     Path to checkpoint file to be written out.

optional arguments:
  -h, --help            show this help message and exit

TensorShape

Этот класс был упрощен для хранения int вместо объектов tf.compat.v1.Dimension . Таким образом, нет необходимости вызывать .value чтобы получить int .

Отдельные объекты tf.compat.v1.Dimension по-прежнему доступны из tf.TensorShape.dims .

Ниже показаны различия между TensorFlow 1.x и TensorFlow 2.x.

# Create a shape and choose an index
i = 0
shape = tf.TensorShape([16, None, 256])
shape
TensorShape([16, None, 256])

Если у вас это было в TensorFlow 1.x:

value = shape[i].value

Затем сделайте это в TensorFlow 2.x:

value = shape[i]
value
16

Если у вас это было в TensorFlow 1.x:

for dim in shape:
    value = dim.value
    print(value)

Затем сделайте это в TensorFlow 2.x:

for value in shape:
  print(value)
16
None
256

Если у вас это было в TensorFlow 1.x (или вы использовали любой другой метод измерения):

dim = shape[i]
dim.assert_is_compatible_with(other_dim)

Затем сделайте это в TensorFlow 2.x:

other_dim = 16
Dimension = tf.compat.v1.Dimension

if shape.rank is None:
  dim = Dimension(None)
else:
  dim = shape.dims[i]
dim.is_compatible_with(other_dim) # or any other dimension method
True
shape = tf.TensorShape(None)

if shape:
  dim = shape.dims[i]
  dim.is_compatible_with(other_dim) # or any other dimension method

tf.TensorShape значение tf.TensorShape равно True если ранг известен, и False противном случае.

print(bool(tf.TensorShape([])))      # Scalar
print(bool(tf.TensorShape([0])))     # 0-length vector
print(bool(tf.TensorShape([1])))     # 1-length vector
print(bool(tf.TensorShape([None])))  # Unknown-length vector
print(bool(tf.TensorShape([1, 10, 100])))       # 3D tensor
print(bool(tf.TensorShape([None, None, None]))) # 3D tensor with no known dimensions
print()
print(bool(tf.TensorShape(None)))  # A tensor with unknown rank.
True
True
True
True
True
True

False

Прочие изменения

  • Удалить tf.colocate_with : алгоритмы размещения устройств в tf.colocate_with значительно улучшились. В этом больше не должно быть необходимости. Если его удаление приводит к снижению производительности, сообщите об ошибке .

  • Замените использование v1.ConfigProto эквивалентными функциями из tf.config .

Выводы

Общий процесс:

  1. Запустите сценарий обновления.
  2. Удалите символы contrib.
  3. Переключите свои модели на объектно-ориентированный стиль (Keras).
  4. По возможности используйте tf.keras или tf.estimator и циклы оценки.
  5. В противном случае используйте пользовательские циклы, но избегайте сеансов и коллекций.

Преобразование кода в идиоматический TensorFlow 2.x требует небольшой работы, но каждое изменение приводит к:

  • Меньше строк кода.
  • Повышенная ясность и простота.
  • Более легкая отладка.