¿Tengo una pregunta? Conéctese con la comunidad en el Foro de visita del foro de TensorFlow

Estimadores

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Este documento presenta tf.estimator una API de TensorFlow de alto nivel. Los estimadores encapsulan las siguientes acciones:

  • Capacitación
  • Evaluación
  • Predicción
  • Exportar para servir

TensorFlow implementa varios Estimadores prediseñados. Los estimadores personalizados todavía se admiten, pero principalmente como una medida de compatibilidad con versiones anteriores. Los estimadores personalizados no deben usarse para código nuevo . Todos los estimadores, prefabricados o personalizados, son clases basadas en la clase tf.estimator.Estimator .

Para ver un ejemplo rápido, pruebe los tutoriales de Estimator . Para obtener una descripción general del diseño de la API, consulte el documento técnico .

Configuración

pip install -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

Ventajas

Similar a tf.keras.Model , un estimator es una abstracción a nivel de modelo. El tf.estimator proporciona algunas capacidades actualmente en desarrollo para tf.keras . Estos son:

  • Entrenamiento basado en servidor de parámetros
  • Integración TFX completa

Capacidades de los estimadores

Los estimadores brindan los siguientes beneficios:

  • Puede ejecutar modelos basados ​​en Estimator en un host local o en un entorno distribuido de varios servidores sin cambiar su modelo. Además, puede ejecutar modelos basados ​​en Estimator en CPU, GPU o TPU sin recodificar su modelo.
  • Los estimadores proporcionan un ciclo de entrenamiento distribuido seguro que controla cómo y cuándo:
    • Cargar datos
    • Manejar excepciones
    • Cree archivos de puntos de control y recupérese de fallas
    • Guardar resúmenes para TensorBoard

Al escribir una aplicación con Estimadores, debe separar la canalización de entrada de datos del modelo. Esta separación simplifica los experimentos con diferentes conjuntos de datos.

Usar estimadores prediseñados

Los Estimadores prediseñados te permiten trabajar a un nivel conceptual mucho más alto que las API básicas de TensorFlow. Ya no tiene que preocuparse por crear el gráfico o las sesiones computacionales, ya que los estimadores se encargan de toda la "plomería" por usted. Además, los Estimadores prediseñados le permiten experimentar con diferentes arquitecturas de modelos al realizar solo cambios mínimos en el código. tf.estimator.DNNClassifier , por ejemplo, es una clase de Estimador prediseñada que entrena modelos de clasificación basados ​​en redes neuronales densas y de avance.

Un programa de TensorFlow que se basa en un Estimador prediseñado generalmente consta de los siguientes cuatro pasos:

1. Escribe una función de entrada

Por ejemplo, puede crear una función para importar el conjunto de entrenamiento y otra función para importar el conjunto de prueba. Los estimadores esperan que sus entradas tengan el formato de un par de objetos:

  • Un diccionario en el que las claves son nombres de funciones y los valores son tensores (o sensores dispersos) que contienen los datos de funciones correspondientes.
  • Un tensor que contiene una o más etiquetas

input_fn debería devolver untf.data.Dataset que produzca pares en ese formato.

Por ejemplo, el siguiente código crea untf.data.Dataset partir del archivo train.csv del conjunto de datos Titanic:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.AUTOTUNE))
  return titanic_batches

input_fn se ejecuta en un tf.Graph y también puede devolver directamente un par (features_dics, labels) contiene tensores de gráfico, pero esto es propenso a errores fuera de casos simples como devolver constantes.

2. Defina las columnas de características.

Cada tf.feature_column identifica un nombre de característica, su tipo y cualquier preprocesamiento de entrada.

Por ejemplo, el siguiente fragmento crea tres columnas de funciones.

  • El primero usa la característica de age directamente como una entrada de punto flotante.
  • El segundo usa la característica de class como entrada categórica.
  • El tercero usa embark_town como entrada categórica, pero usa el hashing trick para evitar la necesidad de enumerar las opciones y establecer el número de opciones.

Para obtener más información, consulte el tutorial de columnas de funciones .

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. Cree una instancia del Estimador prefabricado relevante.

Por ejemplo, aquí hay una instancia de ejemplo de un Estimador LinearClassifier llamado LinearClassifier :

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpeqzx9get', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Para obtener más información, puede ir al tutorial del clasificador lineal .

4. Llame a un método de capacitación, evaluación o inferencia.

Todos los estimadores proporcionan métodos de train , evaluate y predict .

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1700: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:149: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeqzx9get/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpeqzx9get/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.54946315.
result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-06-19T01:21:21
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpeqzx9get/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.63381s
INFO:tensorflow:Finished evaluation at 2021-06-19-01:21:21
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.646875, accuracy_baseline = 0.6, auc = 0.69405115, auc_precision_recall = 0.6043487, average_loss = 0.64180285, global_step = 100, label/mean = 0.4, loss = 0.64180285, precision = 0.72727275, prediction/mean = 0.3058043, recall = 0.1875
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpeqzx9get/model.ckpt-100
accuracy : 0.646875
accuracy_baseline : 0.6
auc : 0.69405115
auc_precision_recall : 0.6043487
average_loss : 0.64180285
label/mean : 0.4
loss : 0.64180285
precision : 0.72727275
prediction/mean : 0.3058043
recall : 0.1875
global_step : 100
for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpeqzx9get/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-1.5908554]
logistic : [0.16926359]
probabilities : [0.83073646 0.16926359]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

Beneficios de los estimadores prefabricados

Los estimadores prediseñados codifican las mejores prácticas y brindan los siguientes beneficios:

  • Mejores prácticas para determinar dónde deben ejecutarse las diferentes partes del gráfico computacional, implementando estrategias en una sola máquina o en un clúster.
  • Mejores prácticas para la redacción de eventos (resúmenes) y resúmenes de utilidad universal.

Si no utiliza Estimadores prediseñados, debe implementar las funciones anteriores usted mismo.

Estimadores personalizados

El corazón de cada Estimador, ya sea prefabricado o personalizado, es su función de modelo , model_fn , que es un método que crea gráficos para entrenamiento, evaluación y predicción. Cuando usa un Estimador prediseñado, alguien más ya implementó la función del modelo. Cuando confíe en un Estimador personalizado, debe escribir la función del modelo usted mismo.

Cree un estimador a partir de un modelo de Keras

Puede convertir modelos Keras existentes en Estimadores con tf.keras.estimator.model_to_estimator . Esto es útil si desea modernizar el código de su modelo, pero su canal de capacitación aún requiere Estimadores.

Cree una instancia de un modelo de Keras MobileNet V2 y compile el modelo con el optimizador, la pérdida y las métricas para entrenar con:

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Cree un Estimator partir del modelo compilado de Keras. El estado del modelo inicial del modelo de Keras se conserva en el Estimator creado:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpamo11374
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/backend.py:435: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  category=CustomMaskWarning)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpamo11374', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Trate el Estimator derivado como lo haría con cualquier otro Estimator .

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

Para entrenar, llame a la función de tren de Estimator:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpamo11374/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpamo11374/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmpamo11374/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmp/tmpamo11374/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpamo11374/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpamo11374/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.70249426, step = 0
INFO:tensorflow:loss = 0.70249426, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpamo11374/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpamo11374/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.6834691.
INFO:tensorflow:Loss for final step: 0.6834691.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fcaf1c34cd0>

De manera similar, para evaluar, llame a la función de evaluación del Estimador:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:2426: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-06-19T01:21:56
INFO:tensorflow:Starting evaluation at 2021-06-19T01:21:56
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpamo11374/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmpamo11374/model.ckpt-50
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 5.68835s
INFO:tensorflow:Inference Time : 5.68835s
INFO:tensorflow:Finished evaluation at 2021-06-19-01:22:02
INFO:tensorflow:Finished evaluation at 2021-06-19-01:22:02
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.478125, global_step = 50, loss = 0.671334
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.478125, global_step = 50, loss = 0.671334
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpamo11374/model.ckpt-50
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpamo11374/model.ckpt-50
{'accuracy': 0.478125, 'loss': 0.671334, 'global_step': 50}

Para obtener más detalles, consulte la documentación de tf.keras.estimator.model_to_estimator .

Guardar puntos de control basados ​​en objetos con Estimator

Los estimadores guardan de forma predeterminada los puntos de control con nombres de variables en lugar del gráfico de objetos descrito en la guía de puntos de control . tf.train.Checkpoint leerá los puntos de control basados ​​en el nombre, pero los nombres de las variables pueden cambiar cuando se mueven partes de un modelo fuera del model_fn del Estimador. Para compatibilidad con versiones posteriores, guardar puntos de control basados ​​en objetos hace que sea más fácil entrenar un modelo dentro de un Estimator y luego usarlo fuera de uno.

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.531598, step = 0
INFO:tensorflow:loss = 4.531598, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 37.41622.
INFO:tensorflow:Loss for final step: 37.41622.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fcaf1b75210>

tf.train.Checkpoint puede cargar los puntos de control del Estimador desde su model_dir .

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

Modelos guardados de estimadores

Los estimadores exportan modelos guardados a través de tf.Estimator.export_saved_model .

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpnh9mbjji
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpnh9mbjji
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpnh9mbjji', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpnh9mbjji', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpnh9mbjji/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpnh9mbjji/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpnh9mbjji/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpnh9mbjji/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.41857475.
INFO:tensorflow:Loss for final step: 0.41857475.
<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7fca8c0e7f90>

Para guardar un Estimator , debe crear un serving_input_receiver . Esta función crea una parte de un tf.Graph que analiza los datos brutos recibidos por SavedModel.

El módulo tf.estimator.export contiene funciones para ayudar a construir estos receivers .

El siguiente código crea un receptor, basado en feature_columns , que acepta tf.Example protocolo tf.Example serializados, que a menudo se usan con tf-serve .

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmp/tmpnh9mbjji/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmpnh9mbjji/model.ckpt-50
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /tmp/tmpueh7p422/from_estimator/temp-1624065724/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tmpueh7p422/from_estimator/temp-1624065724/saved_model.pb

También puede cargar y ejecutar ese modelo, desde python:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.3068818]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.57612395]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.42387608, 0.57612395]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>}
{'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1466763]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.24109668]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7589033 , 0.24109669]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}

tf.estimator.export.build_raw_serving_input_receiver_fn permite crear funciones de entrada que toman tensores sin procesar en lugar de tf.train.Example .

Usando tf.distribute.Strategy con Estimator (soporte limitado)

tf.estimator es una API de TensorFlow de entrenamiento distribuido que originalmente admitía el enfoque de servidor de parámetros asíncronos. tf.estimator ahora es compatible con tf.distribute.Strategy . Si está utilizando tf.estimator , puede cambiar al entrenamiento distribuido con muy pocos cambios en su código. Con esto, los usuarios de Estimator ahora pueden realizar entrenamiento distribuido sincrónico en múltiples GPU y múltiples trabajadores, así como también usar TPU. Sin embargo, este soporte en Estimator es limitado. Consulte la sección Qué se admite ahora a continuación para obtener más detalles.

El uso de tf.distribute.Strategy con Estimator es ligeramente diferente que en el caso de Keras. En lugar de utilizar strategy.scope , ahora pasa el objeto de estrategia a RunConfig para el Estimator.

Puede consultar la guía de formación distribuida para obtener más información.

Aquí hay un fragmento de código que muestra esto con un Estimator LinearRegressor y MirroredStrategy LinearRegressor :

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
WARNING:tensorflow:Collective ops is not configured at program startup. Some performance features may not be enabled.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmphjmg1q2m
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmphjmg1q2m
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmphjmg1q2m', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fca2c1de350>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fca2c1de350>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmphjmg1q2m', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fca2c1de350>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fca2c1de350>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

Aquí, usa un Estimador prediseñado, pero el mismo código también funciona con un Estimador personalizado. train_distribute determina cómo se distribuirá el entrenamiento y eval_distribute determina cómo se distribuirá la evaluación. Esta es otra diferencia con Keras, donde se usa la misma estrategia tanto para el entrenamiento como para la evaluación.

Ahora puede entrenar y evaluar este Estimador con una función de entrada:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmphjmg1q2m/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmphjmg1q2m/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmphjmg1q2m/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmphjmg1q2m/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-06-19T01:22:08
INFO:tensorflow:Starting evaluation at 2021-06-19T01:22:08
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmphjmg1q2m/model.ckpt-10
INFO:tensorflow:Restoring parameters from /tmp/tmphjmg1q2m/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.22064s
INFO:tensorflow:Inference Time : 0.22064s
INFO:tensorflow:Finished evaluation at 2021-06-19-01:22:08
INFO:tensorflow:Finished evaluation at 2021-06-19-01:22:08
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmphjmg1q2m/model.ckpt-10
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmphjmg1q2m/model.ckpt-10
{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

Otra diferencia a destacar aquí entre Estimator y Keras es el manejo de entrada. En Keras, cada lote del conjunto de datos se divide automáticamente entre las múltiples réplicas. En Estimator, sin embargo, no realiza la división automática por lotes, ni fragmenta automáticamente los datos entre diferentes trabajadores. Tiene control total sobre cómo desea que se distribuyan sus datos entre los trabajadores y los dispositivos, y debe proporcionar un input_fn para especificar cómo distribuir sus datos.

Su input_fn se llama una vez por trabajador, lo que proporciona un conjunto de datos por trabajador. Luego, un lote de ese conjunto de datos se alimenta a una réplica en ese trabajador, consumiendo así N lotes para N réplicas en 1 trabajador. En otras palabras, el conjunto de datos devuelto por input_fn debe proporcionar lotes de tamaño PER_REPLICA_BATCH_SIZE . Y el tamaño de lote global para un paso se puede obtener como PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync .

Al realizar la capacitación de varios trabajadores, debe dividir sus datos entre los trabajadores o mezclarlos con una semilla aleatoria en cada uno. Puede consultar un ejemplo de cómo hacer esto en el tutorial Capacitación para varios trabajadores con Estimator .

Y de manera similar, también puede usar estrategias de servidor de parámetros y de trabajadores múltiples. El código sigue siendo el mismo, pero debe usar tf.estimator.train_and_evaluate y configurar las variables de entorno TF_CONFIG para cada binario que se ejecute en su clúster.

¿Qué es compatible ahora?

Hay soporte limitado para entrenar con Estimator usando todas las estrategias excepto TPUStrategy . La capacitación y la evaluación básicas deberían funcionar, pero una serie de funciones avanzadas como v1.train.Scaffold no lo hacen. También puede haber una serie de errores en esta integración y no hay planes para mejorar activamente este soporte (la atención se centra en Keras y el soporte de bucle de entrenamiento personalizado). Si es posible, debería preferir utilizar tf.distribute con esas API.

API de entrenamiento EspejoEstrategia TPUStrategy MultiWorkerMirroredStrategy CentralStorageStrategy ParameterServerStrategy
API de estimador Soporte limitado No soportado Soporte limitado Soporte limitado Soporte limitado

Ejemplos y tutoriales

A continuación, se muestran algunos ejemplos de un extremo a otro que muestran cómo utilizar varias estrategias con Estimator:

  1. El tutorial Capacitación de varios trabajadores con Estimator muestra cómo se puede capacitar con varios trabajadores mediante MultiWorkerMirroredStrategy en el conjunto de datos MNIST.
  2. Un ejemplo de un extremo a otro de la ejecución de capacitación de varios trabajadores con estrategias de distribución en tensorflow/ecosystem utilizando plantillas de Kubernetes. Comienza con un modelo de Keras y lo convierte en un Estimador usando la API tf.keras.estimator.model_to_estimator .
  3. El modelo oficial de ResNet50 , que se puede entrenar utilizando MirroredStrategy o MultiWorkerMirroredStrategy .