Create an Estimator from a Keras model

Overview

TensorFlow Estimators are fully supported in TensorFlow and can be created from new and existing tf.keras models. This tutorial contains a complete, minimal example of that process.

Setup

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

import numpy as np
import tensorflow_datasets as tfds

Create a simple Keras model.

In Keras, you assemble layers to build models. A model is (usually) a graph of layers. The most common type of model is a stack of layers: the tf.keras.Sequential model.
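To make the "stack of layers" idea concrete, the following plain-Python sketch treats each layer as a function on a list of numbers and builds a sequential model that feeds each layer's output into the next. It is an illustration only, not how Keras implements layers. The `dense`, `relu`, and `sequential` helpers and the hand-picked weights are hypothetical, chosen so the result is easy to check by hand.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Layer = Callable[[Vector], Vector]


def relu(v: float) -> float:
    return max(0.0, v)


def identity(v: float) -> float:
    return v


def dense(weights: Sequence[Sequence[float]], biases: Sequence[float],
          activation: Callable[[float], float] = identity) -> Layer:
    """Fully connected layer: out[j] = activation(sum_i x[i] * weights[i][j] + biases[j])."""
    def layer(x: Vector) -> Vector:
        return [
            activation(sum(xi * row[j] for xi, row in zip(x, weights)) + biases[j])
            for j in range(len(biases))
        ]
    return layer


def sequential(layers: Sequence[Layer]) -> Layer:
    """A stack of layers: the output of each layer is the input of the next."""
    def model(x: Vector) -> Vector:
        for layer in layers:
            x = layer(x)
        return x
    return model


# Hypothetical 2-input network: 2 ReLU hidden units, then 1 linear output unit.
tiny_model = sequential([
    dense([[1.0, -1.0], [2.0, 1.0]], [0.0, 0.5], activation=relu),
    dense([[1.0], [1.0]], [-1.0]),
])
print(tiny_model([1.0, 1.0]))  # [2.5]
```

Hidden unit 0 computes relu(1*1 + 1*2 + 0) = 3, and hidden unit 1 computes relu(1*(-1) + 1*1 + 0.5) = 0.5. The output is therefore 3 + 0.5 - 1 = 2.5.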

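To build a simple, fully-connected network (i.e. a multi-layer perceptron) that outputs one score (logit) for each of the three classes in the iris dataset used below:

model = tf.keras.models.Sequential([
    # 4 input features (the iris measurements) -> 16 hidden units.
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    # During training, randomly drop 20% of the hidden activations.
    tf.keras.layers.Dropout(0.2),
    # One unnormalized score (logit) per class, so no activation here.
    tf.keras.layers.Dense(3)
])

Compile the model and get a summary. The iris labels are integer class IDs (0, 1, 2) and the last layer outputs raw logits, so the loss is sparse categorical cross-entropy with from_logits=True.

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam')
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                80        
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 3)                 51        
=================================================================
Total params: 131
Trainable params: 131
Non-trainable params: 0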
_________________________________________________________________
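The Param # column follows from a simple rule. A Dense layer stores one weight for every (input, unit) pair plus one bias per unit, and Dropout has no weights at all. The following plain-Python sketch checks that arithmetic; the helper names are hypothetical.

```python
def dense_param_count(input_dim: int, units: int) -> int:
    """Parameters of a Dense layer: an input_dim x units kernel plus one bias per unit."""
    return input_dim * units + units


def mlp_param_count(input_dim: int, hidden_units: int, output_units: int) -> int:
    """Total for Dense -> Dropout -> Dense; Dropout adds no parameters."""
    return (dense_param_count(input_dim, hidden_units)
            + dense_param_count(hidden_units, output_units))


# First layer of the model above: 4 iris features feeding 16 units.
print(dense_param_count(4, 16))  # 80
```

For the output layer, `dense_param_count(16, n)` gives the count for n output units, and `mlp_param_count(4, 16, n)` gives the model total.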

Create an input function

Use the Datasets API to scale to large datasets or multi-device training.

Estimators need control of when and how their input pipeline is built. To allow this, they require an "Input function" or input_fn. The Estimator will call this function with no arguments. The input_fn must return a tf.data.Dataset.
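The input_fn defined next declares no parameters, so the Estimator calls it with none. More generally, the Estimator inspects an input_fn's signature and passes `mode`, `params`, or `config` only to an input_fn that declares them. The plain-Python sketch below imitates that dispatch with `inspect.signature`. `call_input_fn` and the two sample input functions are hypothetical stand-ins, not part of the tf.estimator API, and they return strings instead of datasets.

```python
import inspect


def call_input_fn(input_fn, mode, params, config):
    """Call input_fn, passing only the optional arguments its signature declares."""
    declared = inspect.signature(input_fn).parameters
    available = {"mode": mode, "params": params, "config": config}
    kwargs = {name: value for name, value in available.items() if name in declared}
    return input_fn(**kwargs)


def no_arg_input_fn():
    return "no arguments"


def mode_aware_input_fn(mode):
    # A real input_fn might, for example, shuffle only when training.
    return "mode=" + mode


print(call_input_fn(no_arg_input_fn, "train", {}, None))     # no arguments
print(call_input_fn(mode_aware_input_fn, "eval", {}, None))  # mode=eval
```

To pass your own settings to an input_fn that is called without arguments, a common pattern is to wrap it in a lambda, such as `input_fn=lambda: make_dataset(batch_size=64)`. Here `make_dataset` is a hypothetical function of your own.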

def input_fn():
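  split = tfds.Split.TRAIN
  # as_supervised=True yields (features, label) tuples instead of a feature dictionary.
  dataset = tfds.load('iris', split=split, as_supervised=True)
  # Key the features by the Keras model's input name ('dense_input' here; check model.input_names).
  dataset = dataset.map(lambda features, labels: ({'dense_input':features}, labels))
  # Batch, then repeat indefinitely; the Estimator's `steps` argument bounds each run.
  dataset = dataset.batch(32).repeat()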
  return dataset
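The order of `batch` and `repeat` determines what each training step sees. By default, `batch(32)` keeps a final partial batch. One pass over the 150 iris examples therefore yields batches of 32, 32, 32, 32, and 22 examples, and `repeat()` replays that sequence without end. The plain-Python sketch below models only these batch sizes and does not use tf.data. The helper names are hypothetical.

```python
from typing import List


def epoch_batch_sizes(num_examples: int, batch_size: int) -> List[int]:
    """Batch sizes for one pass, like Dataset.batch(batch_size) with drop_remainder=False."""
    full, remainder = divmod(num_examples, batch_size)
    return [batch_size] * full + ([remainder] if remainder else [])


def batch_then_repeat(num_examples: int, batch_size: int, steps: int) -> List[int]:
    """Batch sizes seen over `steps` batches of dataset.batch(batch_size).repeat()."""
    epoch = epoch_batch_sizes(num_examples, batch_size)
    return [epoch[i % len(epoch)] for i in range(steps)]


print(epoch_batch_sizes(150, 32))           # [32, 32, 32, 32, 22]
print(sum(batch_then_repeat(150, 32, 25)))  # 750, i.e. exactly 5 passes
```

The 25 training steps requested later therefore cover exactly five passes over the data. Calling `.repeat()` before `.batch(32)` would instead yield only full batches of 32, some of which straddle pass boundaries.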

Test out your input_fn

for features_batch, labels_batch in input_fn().take(1):
  print(features_batch)
  print(labels_batch)
Downloading and preparing dataset iris (4.44 KiB) to /home/kbuilder/tensorflow_datasets/iris/1.0.0...

Dataset iris downloaded and prepared to /home/kbuilder/tensorflow_datasets/iris/1.0.0. Subsequent calls will reuse this data.
{'dense_input': <tf.Tensor: id=200, shape=(32, 4), dtype=float32, numpy=
array([[6.1, 2.8, 4.7, 1.2],
       [5.7, 3.8, 1.7, 0.3],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.9, 4.5, 1.5],
       [6.8, 2.8, 4.8, 1.4],
       [5.4, 3.4, 1.5, 0.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.9, 3.1, 5.1, 2.3],
       [6.2, 2.2, 4.5, 1.5],
       [5.8, 2.7, 3.9, 1.2],
       [6.5, 3.2, 5.1, 2. ],
       [4.8, 3. , 1.4, 0.1],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.1, 3.8, 1.5, 0.3],
       [6.3, 3.3, 4.7, 1.6],
       [6.5, 3. , 5.8, 2.2],
       [5.6, 2.5, 3.9, 1.1],
       [5.7, 2.8, 4.5, 1.3],
       [6.4, 2.8, 5.6, 2.2],
       [4.7, 3.2, 1.6, 0.2],
       [6.1, 3. , 4.9, 1.8],
       [5. , 3.4, 1.6, 0.4],
       [6.4, 2.8, 5.6, 2.1],
       [7.9, 3.8, 6.4, 2. ],
       [6.7, 3. , 5.2, 2.3],
       [6.7, 2.5, 5.8, 1.8],
       [6.8, 3.2, 5.9, 2.3],
       [4.8, 3. , 1.4, 0.3],
       [4.8, 3.1, 1.6, 0.2],
       [4.6, 3.6, 1. , 0.2],
       [5.7, 4.4, 1.5, 0.4]], dtype=float32)>}
tf.Tensor([1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0 0 0], shape=(32,), dtype=int64)
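The labels printed above are integer class IDs (0, 1, or 2) rather than one-hot vectors. For an integer label y, cross-entropy is the negative log of the probability the model assigns to class y. When the model outputs raw logits, a softmax first turns them into probabilities. This is the quantity that Keras's sparse categorical cross-entropy computes with `from_logits=True`. The plain-Python sketch below spells out the formula for a single example; it is an illustration, not the Keras implementation.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn raw scores into probabilities; subtracting the max avoids overflow in exp."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_cross_entropy(logits: Sequence[float], label: int) -> float:
    """-log p[label] with p = softmax(logits); label is an integer class ID."""
    return -math.log(softmax(logits)[label])


# Equal logits give each of the 3 classes probability 1/3, so the loss is log(3).
print(sparse_cross_entropy([0.0, 0.0, 0.0], 2))  # about 1.0986
```

A confident correct prediction (a large logit on the true class) drives the loss toward 0. A confident wrong prediction makes it large.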

Create an Estimator from the tf.keras model.

A tf.keras.Model can be trained with the tf.estimator API by converting the model to a tf.estimator.Estimator object with tf.keras.estimator.model_to_estimator.

model_dir = "/tmp/tfkeras_example/"
keras_estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir=model_dir)
INFO:tensorflow:Using default config.

INFO:tensorflow:Using the Keras model provided.

INFO:tensorflow:Using config: {'_num_ps_replicas': 0, '_session_creation_timeout_secs': 7200, '_global_id_in_cluster': 0, '_save_checkpoints_secs': 600, '_save_checkpoints_steps': None, '_protocol': None, '_log_step_count_steps': 100, '_num_worker_replicas': 1, '_master': '', '_task_id': 0, '_train_distribute': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_is_chief': True, '_experimental_max_worker_delay_secs': None, '_experimental_distribute': None, '_service': None, '_task_type': 'worker', '_eval_distribute': None, '_tf_random_seed': None, '_evaluation_master': '', '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fd37006ed30>, '_device_fn': None, '_save_summary_steps': 100, '_model_dir': '/tmp/tfkeras_example/', '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
}
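The training log below shows the Estimator warm-starting from `/tmp/tfkeras_example/keras/keras_model.ckpt`, the checkpoint written from the Keras model's weights. It then reports `Warm-started 4 variables.`, which corresponds to the kernel and bias of each of the two Dense layers. The following plain-Python sketch is a rough model of that step. It copies a checkpoint value into each trainable variable whose name matches the `vars_to_warm_start` regular expression (`'.*'` in the log). The variable names are hypothetical, and scalars stand in for weight tensors.

```python
import re
from typing import Dict, Iterable


def warm_start(fresh: Dict[str, float], checkpoint: Dict[str, float],
               trainable: Iterable[str], vars_to_warm_start: str = ".*") -> Dict[str, float]:
    """Copy checkpoint values into trainable variables whose names match the regex."""
    pattern = re.compile(vars_to_warm_start)
    result = dict(fresh)
    for name in trainable:
        if pattern.match(name) and name in checkpoint:
            result[name] = checkpoint[name]
    return result


# Hypothetical variable names; scalars stand in for the weight tensors.
keras_checkpoint = {"dense/kernel": 0.3, "dense/bias": 0.0,
                    "dense_1/kernel": -0.2, "dense_1/bias": 0.0}
trainable = list(keras_checkpoint)          # kernel and bias of both Dense layers
fresh = {name: 9.9 for name in trainable}   # freshly initialized Estimator values
fresh["global_step"] = 0.0                  # not trainable, so never warm-started

warmed = warm_start(fresh, keras_checkpoint, trainable)
copied = [name for name in trainable if warmed[name] == keras_checkpoint[name]]
print(len(copied))  # 4
```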

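Train and evaluate the estimator. Both calls use the same input_fn, which reads the iris training split, so the evaluation below measures loss on the training data.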

keras_estimator.train(input_fn=input_fn, steps=25)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tfkeras_example/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting from: /tmp/tfkeras_example/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-started 4 variables.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tfkeras_example/model.ckpt.

INFO:tensorflow:loss = 118.05543, step = 0

INFO:tensorflow:Saving checkpoints for 25 into /tmp/tfkeras_example/model.ckpt.

INFO:tensorflow:Loss for final step: 84.80701.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Starting evaluation at 2019-10-01T01:26:35Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from /tmp/tfkeras_example/model.ckpt-25

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Finished evaluation at 2019-10-01-01:26:36

INFO:tensorflow:Saving dict for global step 25: global_step = 25, loss = 101.27602

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tfkeras_example/model.ckpt-25

Eval result: {'global_step': 25, 'loss': 101.27602}
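Because input_fn repeats the data indefinitely, the `steps` arguments are what stop training and evaluation. According to the Estimator.train documentation, `steps` is incremental: a second `train(steps=25)` call restores the checkpoint at global step 25 and trains to step 50. `max_steps`, by contrast, is an absolute limit on the global step, and passing both arguments is an error. The plain-Python sketch below models only that bookkeeping. `planned_train_steps` and `run_calls` are hypothetical helpers, not part of the tf.estimator API.

```python
from typing import Optional


def planned_train_steps(global_step: int, steps: Optional[int] = None,
                        max_steps: Optional[int] = None) -> Optional[int]:
    """Steps one train() call would run; None means 'until the input runs out'."""
    if steps is not None and max_steps is not None:
        raise ValueError("pass steps or max_steps, not both")
    if steps is not None:
        if steps <= 0:
            raise ValueError("steps must be positive")
        return steps                            # incremental
    if max_steps is not None:
        if max_steps <= 0:
            raise ValueError("max_steps must be positive")
        return max(0, max_steps - global_step)  # absolute cap
    return None  # with dataset.repeat(), the input never runs out


def run_calls(calls: int, **kwargs) -> int:
    """Global step after `calls` successive train() calls with the same arguments."""
    global_step = 0
    for _ in range(calls):
        global_step += planned_train_steps(global_step, **kwargs)
    return global_step


print(run_calls(2, steps=25))      # 50
print(run_calls(2, max_steps=25))  # 25
```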