Estimators

This document introduces tf.estimator—a high-level TensorFlow API. Estimators encapsulate the following actions:

  • training
  • evaluation
  • prediction
  • export for serving

TensorFlow implements several pre-made Estimators. Custom Estimators are still supported, but mainly as a backwards-compatibility measure; they should not be used for new code. All Estimators, whether pre-made or custom, are classes based on the tf.estimator.Estimator class.

For a quick example, try the Estimator tutorials. For an overview of the API design, see the white paper.

Setup

 pip install -q -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

Advantages

Similar to a tf.keras.Model, an Estimator is a model-level abstraction. tf.estimator provides some capabilities that are currently still under development for tf.keras. These are:

  • Parameter server-based training
  • Full TFX integration

Estimator capabilities

Estimators provide the following benefits:

  • You can run Estimator-based models on a local host or on a distributed multi-server environment without changing your model. Furthermore, you can run Estimator-based models on CPUs, GPUs, or TPUs without recoding your model.
  • Estimators provide a safe distributed training loop that controls how and when to:
    • load data
    • handle exceptions
    • create checkpoint files and recover from failures
    • save summaries for TensorBoard

When writing an application with Estimators, you must separate the data input pipeline from the model. This separation simplifies experiments with different data sets.

Using pre-made Estimators

Pre-made Estimators enable you to work at a much higher conceptual level than the base TensorFlow APIs. You no longer have to worry about creating the computational graph or sessions since Estimators handle all the "plumbing" for you. Furthermore, pre-made Estimators let you experiment with different model architectures by making only minimal code changes. tf.estimator.DNNClassifier, for example, is a pre-made Estimator class that trains classification models based on dense, feed-forward neural networks.
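
For instance, a minimal sketch of instantiating such a classifier (the feature column and layer sizes here are illustrative, not part of this guide's running example):

classifier = tf.estimator.DNNClassifier(
    feature_columns=[tf.feature_column.numeric_column('x')],
    hidden_units=[16, 8],  # two hidden layers with 16 and 8 units
    n_classes=2)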

A TensorFlow program relying on a pre-made Estimator typically consists of the following four steps:

1. Write an input function.

For example, you might create one function to import the training set and another function to import the test set. Estimators expect their inputs to be formatted as a pair of objects:

  • A dictionary in which the keys are feature names and the values are Tensors (or SparseTensors) containing the corresponding feature data
  • A Tensor containing one or more labels

The input_fn should return a tf.data.Dataset that yields pairs in that format.

For example, the following code builds a tf.data.Dataset from the Titanic dataset's train.csv file:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.experimental.AUTOTUNE))
  return titanic_batches
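
A matching evaluation input function can be written the same way. The sketch below assumes a companion eval.csv split exists in the same storage bucket, and reads it once without shuffling:

def eval_input_fn():
  titanic_eval_file = tf.keras.utils.get_file(
      "eval.csv", "https://storage.googleapis.com/tf-datasets/titanic/eval.csv")
  # Read the evaluation split once, in a fixed order.
  titanic_eval = tf.data.experimental.make_csv_dataset(
      titanic_eval_file, batch_size=32,
      label_name="survived", num_epochs=1, shuffle=False)
  return titanic_eval.prefetch(tf.data.experimental.AUTOTUNE)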

The input_fn is executed in a tf.Graph and can also directly return a (features_dict, labels) pair containing graph tensors, but this is error-prone outside of simple cases like returning constants.
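
For instance, a minimal sketch of that simpler style, returning constant tensors directly (the feature values here are illustrative):

def constant_input_fn():
  # Return a (features_dict, labels) pair of graph tensors directly.
  features = {"age": tf.constant([29.0, 40.0])}
  labels = tf.constant([0, 1])
  return features, labels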

2. Define the feature columns.

Each tf.feature_column identifies a feature name, its type, and any input pre-processing.

For example, the following snippet creates three feature columns.

  • The first uses the age feature directly as a floating-point input.
  • The second uses the class feature as a categorical input.
  • The third uses the embark_town feature as a categorical input, but uses the hashing trick to avoid the need to enumerate the options; you only set the number of hash buckets.

For further information, see the feature columns tutorial.

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. Instantiate the relevant pre-made Estimator.

For example, here's a sample instantiation of a pre-made Estimator named LinearClassifier:

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp_2fgw1gd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

For further information, see the linear classifier tutorial.

4. Call a training, evaluation, or inference method.

All Estimators provide train, evaluate, and predict methods.

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1481: Layer.add_variable (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:112: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp_2fgw1gd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmp_2fgw1gd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.6098593.

result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-10-15T01:25:18Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp_2fgw1gd/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.63935s
INFO:tensorflow:Finished evaluation at 2020-10-15-01:25:19
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.7, accuracy_baseline = 0.603125, auc = 0.70968133, auc_precision_recall = 0.6162292, average_loss = 0.6068252, global_step = 100, label/mean = 0.396875, loss = 0.6068252, precision = 0.6962025, prediction/mean = 0.3867289, recall = 0.43307087
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmp_2fgw1gd/model.ckpt-100
accuracy : 0.7
accuracy_baseline : 0.603125
auc : 0.70968133
auc_precision_recall : 0.6162292
average_loss : 0.6068252
label/mean : 0.396875
loss : 0.6068252
precision : 0.6962025
prediction/mean : 0.3867289
recall : 0.43307087
global_step : 100

for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp_2fgw1gd/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [0.6824188]
logistic : [0.6642783]
probabilities : [0.33572164 0.6642783 ]
class_ids : [1]
classes : [b'1']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

Benefits of pre-made Estimators

Pre-made Estimators encode best practices, providing the following benefits:

  • Best practices for determining where different parts of the computational graph should run, implementing strategies on a single machine or on a cluster.
  • Best practices for event (summary) writing and universally useful summaries.

If you don't use pre-made Estimators, you must implement the preceding features yourself.

Custom Estimators

The heart of every Estimator—whether pre-made or custom—is its model function, model_fn, which is a method that builds graphs for training, evaluation, and prediction. When you are using a pre-made Estimator, someone else has already implemented the model function. When relying on a custom Estimator, you must write the model function yourself.
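
As a rough sketch (not taken from any shipped Estimator), a model_fn receives features, labels, and a mode, and must return a tf.estimator.EstimatorSpec for that mode. The single-feature linear model below is purely illustrative:

def my_model_fn(features, labels, mode):
  # Build the model: one dense layer producing a single logit.
  logits = tf.keras.layers.Dense(1)(features["x"])
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode, predictions={"logits": logits})
  loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
      labels=tf.reshape(tf.cast(labels, tf.float32), [-1, 1]), logits=logits))
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(mode, loss=loss)
  # TRAIN: minimize the loss and advance the global step.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.05)
  train_op = optimizer.minimize(
      loss, global_step=tf.compat.v1.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)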

Create an Estimator from a Keras model

You can convert existing Keras models to Estimators with tf.keras.estimator.model_to_estimator. This is helpful if you want to modernize your model code, but your training pipeline still requires Estimators.

Instantiate a Keras MobileNet V2 model and compile the model with the optimizer, loss, and metrics to train with:

import tensorflow as tf
import tensorflow_datasets as tfds

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Create an Estimator from the compiled Keras model. The initial model state of the Keras model is preserved in the created Estimator:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpy7tafocp
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py:220: set_learning_phase (from tensorflow.python.keras.backend) is deprecated and will be removed after 2020-10-11.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpy7tafocp', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Treat the derived Estimator as you would any other Estimator.

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

To train, call the Estimator's train function:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

WARNING:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteMUGW8X/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpy7tafocp/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmpy7tafocp/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpy7tafocp/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.68286127, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpy7tafocp/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.70231926.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f1a6c4d4cf8>

Similarly, to evaluate, call the Estimator's evaluate function:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:26Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpy7tafocp/model.ckpt-50
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 1.92025s
INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:28
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.565625, global_step = 50, loss = 0.6713216
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpy7tafocp/model.ckpt-50

{'accuracy': 0.565625, 'loss': 0.6713216, 'global_step': 50}

For more details, please refer to the documentation for tf.keras.estimator.model_to_estimator.

Saving object-based checkpoints with Estimator

Estimators by default save checkpoints with variable names rather than the object graph described in the Checkpoint guide. tf.train.Checkpoint will read name-based checkpoints, but variable names may change when moving parts of a model outside of the Estimator's model_fn. For forward compatibility, saving object-based checkpoints makes it easier to train a model inside an Estimator and then use it outside of one.

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.5633583, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 37.95615.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f1a6c477630>

tf.train.Checkpoint can then load the Estimator's checkpoints from its model_dir.

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

SavedModels from Estimators

Estimators export SavedModels through tf.estimator.Estimator.export_saved_model.

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpn_8rzqza
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpn_8rzqza', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpn_8rzqza/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpn_8rzqza/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.35876164.

<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f1a6c448b00>

To save an Estimator, you need to create a serving_input_receiver. This function builds a part of a tf.Graph that parses the raw data received by the SavedModel.

The tf.estimator.export module contains functions to help build these receivers.

The following code builds a receiver, based on the feature_columns, that accepts serialized tf.Example protocol buffers, which are often used with TensorFlow Serving.

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmp/tmpn_8rzqza/model.ckpt-50
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /tmp/tmptcppevt7/from_estimator/temp-1602725189/saved_model.pb

You can also load and run that model from Python:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.43590346, 0.5640965 ]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2578045]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.5640965]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>}
{'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7984398 , 0.20156018]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.3765715]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2015602]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}

tf.estimator.export.build_raw_serving_input_receiver_fn allows you to create input functions that take raw tensors rather than tf.train.Examples.
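
For example, here is a minimal sketch of a raw receiver for the "x" feature, written as a custom serving_input_receiver_fn so the placeholder is created inside the export graph (the tensor name and shape are illustrative):

def raw_serving_input_fn():
  # A placeholder that receives raw float values at serving time.
  x = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None], name="x")
  return tf.estimator.export.ServingInputReceiver(
      features={"x": x}, receiver_tensors={"x": x})

raw_path = estimator.export_saved_model(
    os.path.join(tmpdir, 'from_estimator_raw'), raw_serving_input_fn)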

Using tf.distribute.Strategy with Estimator (Limited support)

See the Distributed training guide for more info.

tf.estimator is a distributed training TensorFlow API that originally supported the async parameter server approach. tf.estimator now supports tf.distribute.Strategy. If you're using tf.estimator, you can change to distributed training with very few changes to your code. With this, Estimator users can now do synchronous distributed training on multiple GPUs and multiple workers, as well as use TPUs. This support in Estimator is, however, limited. See the What's supported now? section below for more details.

The usage of tf.distribute.Strategy with Estimator is slightly different from the Keras case. Instead of using strategy.scope, we pass the strategy object into the RunConfig for the Estimator.

Here is a snippet of code that shows this with a pre-made Estimator, LinearRegressor, and MirroredStrategy:

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmphb70j0wf
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmphb70j0wf', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

We use a pre-made Estimator here, but the same code works with a custom Estimator as well. train_distribute determines how training is distributed, and eval_distribute determines how evaluation is distributed. This is another difference from Keras, where you use the same strategy for both training and evaluation.

Now we can train and evaluate this Estimator with an input function:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:339: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1b28ec8b70> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmphb70j0wf/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmphb70j0wf/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1ad8062d08> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:34Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmphb70j0wf/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.23888s
INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:34
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmphb70j0wf/model.ckpt-10

{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

Another difference to highlight between Estimator and Keras is input handling. In Keras, each batch of the dataset is split automatically across the multiple replicas. In Estimator, however, there is no automatic splitting of batches, nor automatic sharding of data across workers. You have full control over how your data is distributed across workers and devices, and you must provide an input_fn that specifies how to distribute it.

Your input_fn is called once per worker, thus giving one dataset per worker. Then one batch from that dataset is fed to one replica on that worker, thereby consuming N batches for N replicas on one worker. In other words, the dataset returned by the input_fn should provide batches of size PER_REPLICA_BATCH_SIZE, and the global batch size for a step can be obtained as PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync, as sketched below.
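
As a minimal sketch of that arithmetic, using a hypothetical per-replica batch size with the regressor's feats feature:

PER_REPLICA_BATCH_SIZE = 8  # hypothetical; tune for your devices

def per_replica_input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats": [1.]}, [1.]))
  # Each replica receives one batch of this size per step.
  return dataset.repeat(1000).batch(PER_REPLICA_BATCH_SIZE)

# Effective number of examples processed per training step:
global_batch_size = (PER_REPLICA_BATCH_SIZE *
                     mirrored_strategy.num_replicas_in_sync)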

When doing multi-worker training, you should either split your data across the workers, or shuffle with a random seed on each. You can see an example of how to do this in the Multi-worker Training with Estimator tutorial, and a sharding sketch below.
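
Here is a minimal sketch of the splitting approach, assuming the worker count and index come from your cluster configuration (tf.data.Dataset.shard performs the per-worker selection):

def sharded_input_fn(num_workers, worker_index):
  dataset = tf.data.Dataset.from_tensor_slices(
      ({"feats": [[1.], [2.], [3.], [4.]]}, [[1.], [2.], [3.], [4.]]))
  # Keep every num_workers-th example starting at worker_index, so each
  # worker sees a disjoint slice of the data.
  return dataset.shard(num_workers, worker_index).repeat(1000).batch(10)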

Similarly, you can use multi-worker and parameter server strategies as well. The code remains the same, but you need to use tf.estimator.train_and_evaluate and set TF_CONFIG environment variables for each binary running in your cluster, as sketched below.
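
A minimal sketch of that setup, with placeholder host addresses (each worker binary sets its own task index):

import json

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["host1:port", "host2:port"]},
    "task": {"type": "worker", "index": 0}  # index 1 on the second worker
})

tf.estimator.train_and_evaluate(
    regressor,
    train_spec=tf.estimator.TrainSpec(input_fn=input_fn, max_steps=10),
    eval_spec=tf.estimator.EvalSpec(input_fn=input_fn))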

What's supported now?

There is limited support for training with Estimator using all strategies except TPUStrategy. Basic training and evaluation should work, but a number of advanced features such as v1.train.Scaffold do not. There may also be a number of bugs in this integration. At this time, we do not plan to actively improve this support, and instead are focused on Keras and custom training loop support. If at all possible, you should prefer to use tf.distribute with those APIs instead.

Training API  | MirroredStrategy | TPUStrategy   | MultiWorkerMirroredStrategy | CentralStorageStrategy | ParameterServerStrategy
Estimator API | Limited support  | Not supported | Limited support             | Limited support        | Limited support

Examples and Tutorials

Here are some examples that show end-to-end usage of various strategies with Estimator:

  1. Multi-worker Training with Estimator to train MNIST with multiple workers using MultiWorkerMirroredStrategy.
  2. End-to-end example for multi-worker training in tensorflow/ecosystem using Kubernetes templates. This example starts with a Keras model and converts it to an Estimator using the tf.keras.estimator.model_to_estimator API.
  3. Official ResNet50 model, which can be trained using either MirroredStrategy or MultiWorkerMirroredStrategy.