
This document introduces tf.estimator—a high-level TensorFlow API that greatly simplifies machine learning programming. Estimators encapsulate the following actions:

  • training
  • evaluation
  • prediction
  • export for serving

You may either use the pre-made Estimators we provide or write your own custom Estimators. All Estimators—whether pre-made or custom—are classes based on the tf.estimator.Estimator class.

For a quick example, try the Estimator tutorials. For an overview of the API design, see the white paper.

Estimator advantages

Estimators provide the following benefits:

  • You can run Estimator-based models on a local host or on a distributed multi-server environment without changing your model. Furthermore, you can run Estimator-based models on CPUs, GPUs, or TPUs without recoding your model.
  • Estimators simplify sharing implementations between model developers.
  • You can develop a state-of-the-art model with high-level, intuitive code. In short, it is generally much easier to create models with Estimators than with the low-level TensorFlow APIs.
  • Estimators are themselves built on tf.keras.layers, which simplifies customization.
  • Estimators build the graph for you.
  • Estimators provide a safe distributed training loop that controls how and when to:
    • build the graph
    • initialize variables
    • load data
    • handle exceptions
    • create checkpoint files and recover from failures
    • save summaries for TensorBoard

When writing an application with Estimators, you must separate the data input pipeline from the model. This separation simplifies experiments with different data sets.

Pre-made Estimators

Pre-made Estimators enable you to work at a much higher conceptual level than the base TensorFlow APIs. You no longer have to worry about creating the computational graph or sessions since Estimators handle all the "plumbing" for you. Furthermore, pre-made Estimators let you experiment with different model architectures by making only minimal code changes. tf.estimator.DNNClassifier, for example, is a pre-made Estimator class that trains classification models based on dense, feed-forward neural networks.

Structure of a pre-made Estimators program

A TensorFlow program relying on a pre-made Estimator typically consists of the following four steps:

1. Write one or more dataset importing functions.

For example, you might create one function to import the training set and another function to import the test set. Each dataset importing function must return two objects:

  • a dictionary in which the keys are feature names and the values are Tensors (or SparseTensors) containing the corresponding feature data
  • a Tensor containing one or more labels

For example, the following code illustrates the basic skeleton for an input function:

def input_fn(dataset):
    ...  # manipulate dataset, extracting the feature dict and the label
    return feature_dict, label

See data guide for details.
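
As a rough illustration of that contract, the sketch below turns row-oriented records into the (feature_dict, label) pair an input function returns. Plain Python lists stand in for Tensors so the example runs without TensorFlow. The record layout, the "label" key, and the helper name rows_to_features_and_labels are assumptions made for illustration and are not part of the Estimator API.

```python
def rows_to_features_and_labels(rows, label_key):
    """Split row-oriented records into a column-oriented feature dict and a label list."""
    feature_names = [name for name in rows[0] if name != label_key]
    # One entry per feature name; each value holds that feature for every example.
    feature_dict = {name: [row[name] for row in rows] for name in feature_names}
    labels = [row[label_key] for row in rows]
    return feature_dict, labels


def input_fn():
    # Hypothetical in-memory data set. A real input function would typically
    # read from files and return Tensors (or a tf.data.Dataset) instead of lists.
    rows = [
        {"population": 1200.0, "crime_rate": 0.02, "label": 1},
        {"population": 560.0, "crime_rate": 0.10, "label": 0},
    ]
    return rows_to_features_and_labels(rows, label_key="label")


feature_dict, label = input_fn()
```

The keys of the returned dictionary are the names that the feature columns in the next step refer to.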

2. Define the feature columns.

Each tf.feature_column identifies a feature name, its type, and any input pre-processing. For example, the following snippet creates three feature columns that hold integer or floating-point data. The first two feature columns simply identify the feature's name and type. The third feature column also specifies a lambda the program will invoke to scale the raw data:

# Define three numeric feature columns.
population = tf.feature_column.numeric_column('population')
crime_rate = tf.feature_column.numeric_column('crime_rate')
# `global_education_mean` must be defined before this column is created.
median_education = tf.feature_column.numeric_column(
  'median_education',
  normalizer_fn=lambda x: x - global_education_mean)

For further information, see the feature columns tutorial.
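
The snippet above refers to global_education_mean without defining it. One plausible way to obtain it, assumed here, is to compute the mean over the training data before building the column. The sketch uses plain Python numbers in place of Tensors, and the sample values are made up.

```python
# Hypothetical training values for the 'median_education' feature.
training_median_education = [10.0, 12.0, 14.0, 16.0]

# Compute the statistic once, up front, so the normalizer only does a subtraction.
global_education_mean = sum(training_median_education) / len(training_median_education)

# Same shape of function as the normalizer_fn lambda: it receives the raw
# feature value and returns the centered value.
normalizer_fn = lambda x: x - global_education_mean

centered = [normalizer_fn(x) for x in training_median_education]
```

Because the lambda closes over global_education_mean, the same centering is applied at training, evaluation, and prediction time.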

3. Instantiate the relevant pre-made Estimator.

For example, here's a sample instantiation of a pre-made Estimator named LinearClassifier:

# Instantiate an estimator, passing the feature columns.
estimator = tf.estimator.LinearClassifier(
  feature_columns=[population, crime_rate, median_education])

For further information, see the linear classifier tutorial.

4. Call a training, evaluation, or inference method.

For example, all Estimators provide a train method, which trains a model.

# `my_training_set` is the input function created in Step 1
estimator.train(input_fn=my_training_set, steps=2000)

You can see an example of this below.

Benefits of pre-made Estimators

Pre-made Estimators encode best practices, providing the following benefits:

  • Best practices for determining where different parts of the computational graph should run, implementing strategies on a single machine or on a cluster.
  • Best practices for event (summary) writing and universally useful summaries.

If you don't use pre-made Estimators, you must implement the preceding features yourself.

Custom Estimators

The heart of every Estimator—whether pre-made or custom—is its model function, which is a method that builds graphs for training, evaluation, and prediction. When you are using a pre-made Estimator, someone else has already implemented the model function. When relying on a custom Estimator, you must write the model function yourself.
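
To make the idea of one model function serving training, evaluation, and prediction more concrete, here is a TensorFlow-free sketch. It only mimics the control flow: a real model function receives Tensors and returns a tf.estimator.EstimatorSpec, whereas this one works on Python lists and returns a plain dict. The mode strings mirror the values of tf.estimator.ModeKeys. The one-weight linear model, the params keys, and the dict layout are assumptions made for illustration.

```python
TRAIN, EVAL, PREDICT = "train", "eval", "infer"  # values of tf.estimator.ModeKeys


def model_fn(features, labels, mode, params):
    """Toy model function: one code path builds outputs for all three modes."""
    # Shared "forward pass": a one-weight linear model over a single feature.
    predictions = [params["weight"] * x for x in features["x"]]

    if mode == PREDICT:
        # Prediction needs neither labels nor a loss.
        return {"mode": mode, "predictions": predictions}

    # Training and evaluation both need a loss (mean squared error here).
    n = len(labels)
    loss = sum((p - y) ** 2 for p, y in zip(predictions, labels)) / n

    if mode == EVAL:
        return {"mode": mode, "loss": loss}

    # Training additionally computes an update: one gradient-descent step.
    grad = sum(2 * (p - y) * x
               for p, y, x in zip(predictions, labels, features["x"])) / n
    new_weight = params["weight"] - params["learning_rate"] * grad
    return {"mode": mode, "loss": loss, "new_weight": new_weight}


features = {"x": [1.0, 2.0]}
labels = [2.0, 4.0]
params = {"weight": 1.0, "learning_rate": 0.1}

train_out = model_fn(features, labels, TRAIN, params)
eval_out = model_fn(features, labels, EVAL, params)
predict_out = model_fn(features, None, PREDICT, params)
```

Keeping the forward pass shared and branching only on the mode is what lets one function define all three graphs consistently.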

Recommended workflow

  1. Assuming a suitable pre-made Estimator exists, use it to build your first model and use its results to establish a baseline.
  2. Build and test your overall pipeline, including the integrity and reliability of your data, with this pre-made Estimator.
  3. If suitable alternative pre-made Estimators are available, run experiments to determine which pre-made Estimator produces the best results.
  4. Possibly, further improve your model by building your own custom Estimator.

from __future__ import absolute_import, division, print_function, unicode_literals

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
import tensorflow_datasets as tfds

Create an Estimator from a Keras model

You can convert existing Keras models to Estimators with tf.keras.estimator.model_to_estimator. Doing so enables your Keras model to access Estimator's strengths, such as distributed training.

Instantiate a Keras MobileNet V2 model and compile the model with the optimizer, loss, and metrics to train with:

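keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    # Sigmoid rather than softmax: softmax over a single unit always outputs 1.
    tf.keras.layers.Dense(1, activation='sigmoid')
])

# Compile the model (assumed settings for binary cats-vs-dogs classification)
estimator_model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'])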
Downloading data from
9412608/9406464 [==============================] - 2s 0us/step

Create an Estimator from the compiled Keras model. The initial model state of the Keras model is preserved in the created Estimator:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
WARNING: Logging before flag parsing goes to stderr.
W0817 00:51:46.436726 140010235815680] Using temporary folder as model directory: /tmp/tmp2z_11ccy
W0817 00:51:52.967502 140010235815680] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/ops/ add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Treat the derived Estimator as you would any other Estimator.

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
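  return image, label

def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  # Preprocess each example, then shuffle and batch.
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data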
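
The scaling step in preprocess, (image/127.5) - 1, maps 8-bit pixel values from [0, 255] to [-1, 1], the input range MobileNet V2 is generally used with. A quick check of that arithmetic with plain Python numbers follows. It needs no TensorFlow, and the helper name scale_pixel is made up for this sketch.

```python
def scale_pixel(value):
    # Same arithmetic as in preprocess, applied to a single pixel value.
    return (value / 127.5) - 1


# Darkest, middle, and brightest values of the 8-bit range.
low, mid, high = scale_pixel(0), scale_pixel(127.5), scale_pixel(255)
```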

To train, call the Estimator's train method:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)
W0817 00:51:59.754434 140010235815680] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/training/ Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.

Downloading and preparing dataset cats_vs_dogs (786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/2.0.1...

/home/kbuilder/.local/lib/python3.5/site-packages/urllib3/ InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See:
W0817 00:52:44.137132 140010235815680] 1738 images were corrupted and were skipped
W0817 00:52:44.149461 140010235815680] From /home/kbuilder/.local/lib/python3.5/site-packages/tensorflow_datasets/core/ tf_record_iterator (from is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: 
W0817 00:52:47.005725 140010235815680] Warning: Setting shuffle_files=True because split=TRAIN and shuffle_files=None. This behavior will be deprecated on 2019-08-06, at which point shuffle_files=False will be the default for all splits.

Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/2.0.1. Subsequent calls will reuse this data.

<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f55311bceb8>

Similarly, to evaluate, call the Estimator's evaluate method:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
W0817 00:54:11.226737 140010235815680] Warning: Setting shuffle_files=True because split=TRAIN and shuffle_files=None. This behavior will be deprecated on 2019-08-06, at which point shuffle_files=False will be the default for all splits.
W0817 00:54:14.660813 140010235815680] From /tmpfs/src/tf_docs_env/lib/python3.5/site-packages/tensorflow/python/training/ checkpoint_exists (from is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.

{'global_step': 500, 'loss': 8.049951}

For more details, please refer to the documentation for tf.keras.estimator.model_to_estimator.