Migrate evaluation

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

Evaluation is a critical part of measuring and benchmarking models.

This guide demonstrates how to migrate evaluator tasks from TensorFlow 1 to TensorFlow 2. In Tensorflow 1 this functionality is implemented by tf.estimator.train_and_evaluate, when the API is running distributedly. In Tensorflow 2, you can use the built-in tf.keras.utils.SidecarEvaluator, or a custom evaluation loop on the evaluator task.

There are simple serial evaluation options in both TensorFlow 1 (tf.estimator.Estimator.evaluate) and TensorFlow 2 (Model.fit(..., validation_data=(...)) or Model.evaluate). The evaluator task is preferable when you would like your workers not switching between training and evaluation, and built-in evaluation in Model.fit is preferable when you would like your evaluation to be distributed.

Setup

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
import os
2024-02-14 02:22:57.545197: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-14 02:22:57.545243: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-14 02:22:57.546705: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 [==============================] - 0s 0us/step

TensorFlow 1: Evaluating using tf.estimator.train_and_evaluate

In TensorFlow 1, you can configure a tf.estimator to evaluate the estimator using tf.estimator.train_and_evaluate.

In this example, start by defining the tf.estimator.Estimator and specifying training and evaluation specifications:

feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]

classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2
)

train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)

test_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_test},
    y=y_test.astype(np.int32),
    num_epochs=10,
    shuffle=False
)

train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
                                   steps=10,
                                   throttle_secs=0)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_12475/122738158.py:1: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.
Instructions for updating:
Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_12475/122738158.py:3: DNNClassifier.__init__ (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:807: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp_1kb1itp
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp_1kb1itp', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_12475/122738158.py:11: numpy_input_fn (from tensorflow_estimator.python.estimator.inputs.numpy_io) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_12475/122738158.py:26: TrainSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_12475/122738158.py:27: EvalSpec.__new__ (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.

Then, train and evaluate the model. The evaluation runs synchronously between training because it's limited as a local run in this notebook and alternates between training and evaluation. However, if the estimator is used distributedly, the evaluator will run as a dedicated evaluator task. For more information, check the migration guide on distributed training.

tf1.estimator.train_and_evaluate(estimator=classifier,
                                train_spec=train_spec,
                                eval_spec=eval_spec)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_12475/3761202737.py:1: train_and_evaluate (from tensorflow_estimator.python.estimator.training) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:385: StopAtStepHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:60: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:491: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:446: dnn_logit_fn_builder (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:910: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
2024-02-14 02:23:03.118506: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT64
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_INT32
    }
  }
}

    for Tuple type infernce function 0
    while inferring type of node 'dnn/zero_fraction/cond/output/_18'
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp_1kb1itp/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:loss = 118.97304, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmpfs/tmp/tmp_1kb1itp/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2024-02-14T02:23:04
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/evaluation.py:260: FinalOpsHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.keras instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp_1kb1itp/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.78981s
INFO:tensorflow:Finished evaluation at 2024-02-14-02:23:05
INFO:tensorflow:Saving dict for global step 10: accuracy = 0.4203125, average_loss = 1.9955362, global_step = 10, loss = 255.42863
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmpfs/tmp/tmp_1kb1itp/model.ckpt-10
INFO:tensorflow:Loss for final step: 97.367035.
({'accuracy': 0.4203125,
  'average_loss': 1.9955362,
  'loss': 255.42863,
  'global_step': 10},
 [])

TensorFlow 2: Evaluating a Keras model

In TensorFlow 2, if you use the Keras Model.fit API for training, you can evaluate the model with tf.keras.utils.SidecarEvaluator. You can also visualize the evaluation metrics in TensorBoard which is not shown in this guide.

To help demonstrate this, let's first start by defining and training the model:

def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
  ])

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10,
              run_eagerly=True)

log_dir = tempfile.mkdtemp()
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(log_dir, 'ckpt-{epoch}'),
    save_weights_only=True)

model.fit(x=x_train,
          y=y_train,
          epochs=1,
          callbacks=[model_checkpoint])
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1707877387.723609   12475 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
1875/1875 [==============================] - 34s 18ms/step - loss: 0.2228 - accuracy: 0.9345
<keras.src.callbacks.History at 0x7f79501fdbe0>

Then, evaluate the model using tf.keras.utils.SidecarEvaluator. In real training, it's recommended to use a separate job to conduct the evaluation to free up worker resources for training.

data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
data = data.batch(64)

tf.keras.utils.SidecarEvaluator(
    model=model,
    data=data,
    checkpoint_dir=log_dir,
    max_evaluations=1
).start()
INFO:tensorflow:Waiting for new checkpoint at /tmpfs/tmp/tmptdfoc3kp
INFO:tensorflow:Found new checkpoint at /tmpfs/tmp/tmptdfoc3kp/ckpt-1
INFO:tensorflow:Evaluation starts: Model weights loaded from latest checkpoint file /tmpfs/tmp/tmptdfoc3kp/ckpt-1
157/157 - 2s - loss: 0.1062 - accuracy: 0.9672 - 2s/epoch - 11ms/step
INFO:tensorflow:End of evaluation. Metrics: loss=0.10622197389602661 accuracy=0.967199981212616
INFO:tensorflow:Last checkpoint evaluated. SidecarEvaluator stops.

Next steps

  • To learn more about sidecar evaluation consider reading the tf.keras.utils.SidecarEvaluator API docs.
  • To consider alternating training and evaluation in Keras consider reading about other built-in methods.