Constructs an Estimator instance from a given Keras model.


If you use infrastructure or other tooling that relies on Estimators, you can still build a Keras model and use model_to_estimator to convert the Keras model to an Estimator for use with downstream systems.

For a usage example, see: Creating estimators from Keras Models.

Sample Weights:

Estimators returned by model_to_estimator are configured to handle sample weights (similar to keras_model.fit(x, y, sample_weights)).

To pass sample weights when training or evaluating the Estimator, the first item returned by the input function should be a dictionary with the keys features and sample_weights. For example:

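import tensorflow as tf

keras_model = tf.keras.Model(...)
keras_model.compile(...)

estimator = tf.keras.estimator.model_to_estimator(keras_model)

def input_fn():
  # The first element is a dict with the keys 'features' and 'sample_weights';
  # the second element holds the targets.
  return tf.data.Dataset.from_tensors(
      ({'features': features, 'sample_weights': sample_weights},
       targets))

estimator.train(input_fn, steps=1)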
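The sketch below illustrates, in plain Python without TensorFlow, what the ({'features': ..., 'sample_weights': ...}, targets) structure carries and how per-example weights typically scale the loss. The helpers split_element and weighted_mean_loss are hypothetical, and the averaging step assumes a reduction that divides the weighted sum by the number of examples (as Keras's default SUM_OVER_BATCH_SIZE reduction does); it is not the code generated by model_to_estimator.

```python
from typing import Any, Dict, Sequence, Tuple


def split_element(element: Tuple[Dict[str, Any], Any]) -> Tuple[Any, Any, Any]:
    """Hypothetical helper: unpack ({'features': ..., 'sample_weights': ...}, targets)."""
    inputs, targets = element
    return inputs["features"], inputs["sample_weights"], targets


def weighted_mean_loss(per_example_loss: Sequence[float],
                       sample_weights: Sequence[float]) -> float:
    """Multiply each example's loss by its weight, then divide by the batch size.

    Assumption: mirrors a SUM_OVER_BATCH_SIZE-style reduction, i.e. the
    divisor is the number of examples, not the sum of the weights.
    """
    if len(per_example_loss) != len(sample_weights):
        raise ValueError("need exactly one weight per example")
    if not per_example_loss:
        raise ValueError("empty batch")
    weighted_sum = sum(l * w for l, w in zip(per_example_loss, sample_weights))
    return weighted_sum / len(per_example_loss)


# Toy batch with hypothetical values; the third example has weight 0,
# so it contributes nothing to the loss.
element = ({"features": [1.0, 2.0, 3.0], "sample_weights": [1.0, 2.0, 0.0]},
           [1.5, 1.0, 1.0])
features, weights, targets = split_element(element)
predictions = features  # stand-in "model": the identity function
per_example = [(p - t) ** 2 for p, t in zip(predictions, targets)]
loss = weighted_mean_loss(per_example, weights)
print(loss)
```

A weight of 0 effectively masks an example, while a weight of 2 counts it twice as heavily as a weight of 1.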

To customize the names of the Estimator's eval_metric_ops, pass in a metric_names_map dictionary that maps the Keras model's output metric names to custom names, as follows:

  import tensorflow as tf

  input_a = tf.keras.layers.Input(shape=(16,), name='input_a')
  input_b = tf.keras.layers.Input(shape=(16,), name='input_b')
  dense = tf.keras.layers.Dense(8, name='dense_1')
  interm_a = dense(input_a)
  interm_b = dense(input_b)
  merged = tf.keras.layers.concatenate([interm_a, interm_b], name='merge')
  output_a = tf.keras.layers.Dense(3, activation='softmax', name='dense_2')(
      merged)
  output_b = tf.keras.layers.Dense(2, activation='softmax', name='dense_3')(
      merged)
  keras_model = tf.keras.models.Model(
      inputs=[input_a, input_b], outputs=[output_a, output_b])
  keras_model.compile(
      loss='categorical_crossentropy',
      optimizer='rmsprop',
      metrics={
          'dense_2': 'categorical_accuracy',
          'dense_3': 'categorical_accuracy'
      })

  # Map the default Keras metric names (<output name>_<metric name>)
  # to the names reported in the Estimator's eval_metric_ops.
  metric_names_map = {
      'dense_2_categorical_accuracy': 'acc_1',
      'dense_3_categorical_accuracy': 'acc_2',
  }
  keras_est = tf.keras.estimator.model_to_estimator(
      keras_model=keras_model,
      metric_names_map=metric_names_map)
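As a plain-Python illustration of the renaming, the sketch below builds metric names the way the keys of metric_names_map suggest (output layer name, an underscore, then the metric name) and swaps them for the custom names. The helpers keras_metric_name and rename_metrics and the accuracy values are hypothetical, and keeping unmapped names unchanged is an assumption of this sketch rather than documented library behavior.

```python
from typing import Dict, Mapping


def keras_metric_name(output_name: str, metric_name: str) -> str:
    """Prefix a metric with its output layer name, e.g. 'dense_2_categorical_accuracy'."""
    return f"{output_name}_{metric_name}"


def rename_metrics(metric_values: Mapping[str, float],
                   metric_names_map: Mapping[str, str]) -> Dict[str, float]:
    """Return a copy of metric_values keyed by the custom names.

    Assumption: names missing from metric_names_map are kept as-is.
    """
    return {metric_names_map.get(name, name): value
            for name, value in metric_values.items()}


# Same per-output metrics as in the compile() call above.
compiled_metrics = {"dense_2": "categorical_accuracy",
                    "dense_3": "categorical_accuracy"}
# Hypothetical evaluation results keyed by the default multi-output names.
results = {keras_metric_name(output, metric): value
           for (output, metric), value in zip(compiled_metrics.items(), [0.91, 0.87])}

metric_names_map = {
    "dense_2_categorical_accuracy": "acc_1",
    "dense_3_categorical_accuracy": "acc_2",
}
renamed = rename_metrics(results, metric_names_map)
print(renamed)
```

Because the output layers are named explicitly ('dense_2', 'dense_3'), the default metric names are predictable, which is what makes a static metric_names_map practical.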

Args:

keras_model: A compiled Keras model object. This argument is mutually exclusive with keras_model_path. The Estimator's model_fn uses the structure of the model to clone the model. Defaults to None.
keras_model_path: Path to a compiled Keras model saved on disk in HDF5 format, which can be generated with the save() method of a Keras model. This argument is mutually exclusive with keras_model. Defaults to None.
custom_objects: Dictionary for cloning customized objects. This is used with classes that are not part of this pip package. For example, if a user maintains a relu6 class that inherits from tf.keras.layers.Layer, pass custom_objects={'relu6': relu6}. Defaults to None.
model_dir: Directory in which to save Estimator model parameters, the graph, summary files for TensorBoard, etc. If unset, a directory is created with tempfile.mkdtemp.
config: RunConfig used to configure the Estimator. Allows setting up things in model_fn based on configuration such as num_ps_replicas or model_dir. Defaults to None. If both config.model_dir and the model_dir argument (above) are specified, the model_dir argument takes precedence.
checkpoint_format: Sets the format of the checkpoint saved by the Estimator during training. May be 'saver' or 'checkpoint', depending on whether checkpoints are saved with tf.compat.v1.train.Saver or tf.train.Checkpoint. The default is 'checkpoint'. Estimators use name-based tf.train.Saver checkpoints, while Keras models use object-based checkpoints from tf.train.Checkpoint. Currently,