Evaluation is a critical part of measuring and benchmarking models.
This guide demonstrates how to migrate evaluator tasks from TensorFlow 1 to TensorFlow 2. In TensorFlow 1, this functionality is implemented by tf.estimator.train_and_evaluate when the API runs in a distributed setting. In TensorFlow 2, you can use the built-in tf.keras.utils.SidecarEvaluator or a custom evaluation loop on the evaluator task.
Both TensorFlow 1 (tf.estimator.Estimator.evaluate) and TensorFlow 2 (Model.fit(..., validation_data=(...)) or Model.evaluate) also offer simple serial evaluation options. A dedicated evaluator task is preferable when you do not want your workers to switch between training and evaluation. Built-in evaluation in Model.fit is preferable when you want the evaluation itself to be distributed.
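At its core, the evaluator task runs a polling loop: it waits for the trainer to write a new checkpoint, evaluates that checkpoint, and repeats. The plain-Python sketch below shows only that control flow. The ckpt-<n>.index naming scheme and the latest_checkpoint and evaluation_loop helpers are hypothetical. A real TensorFlow loop would locate checkpoints with a helper such as tf.train.latest_checkpoint, restore the model's weights, and compute metrics on the evaluation data.

```python
import os
import re
import tempfile
import time

# Hypothetical naming scheme: checkpoints appear as "ckpt-<n>.index" files,
# mirroring the 'ckpt-{epoch}' pattern used with ModelCheckpoint in this guide.
_CKPT_RE = re.compile(r"^ckpt-(\d+)\.index$")


def latest_checkpoint(checkpoint_dir):
    """Return (number, prefix) for the newest checkpoint, or None if none exist."""
    found = []
    for name in os.listdir(checkpoint_dir):
        match = _CKPT_RE.match(name)
        if match:
            found.append((int(match.group(1)), name[: -len(".index")]))
    return max(found) if found else None


def evaluation_loop(checkpoint_dir, evaluate_fn, max_evaluations,
                    poll_interval_secs=0.1, timeout_secs=10.0):
    """Evaluate the newest unseen checkpoint until max_evaluations is reached.

    Gives up after timeout_secs so the sketch cannot hang forever.
    """
    results = []
    last_evaluated = None
    deadline = time.monotonic() + timeout_secs
    while len(results) < max_evaluations and time.monotonic() < deadline:
        latest = latest_checkpoint(checkpoint_dir)
        if latest is None or latest == last_evaluated:
            # Nothing new from the trainer yet: wait, then poll again.
            time.sleep(poll_interval_secs)
            continue
        last_evaluated = latest
        prefix = latest[1]
        results.append((prefix, evaluate_fn(os.path.join(checkpoint_dir, prefix))))
    return results


# Usage: fake a trainer that has already written three checkpoints.
ckpt_dir = tempfile.mkdtemp()
for n in (1, 2, 10):
    open(os.path.join(ckpt_dir, f"ckpt-{n}.index"), "w").close()

# A stand-in "metric" that just reports which checkpoint was evaluated.
results = evaluation_loop(ckpt_dir, os.path.basename, max_evaluations=1)
print(results)  # [('ckpt-10', 'ckpt-10')]
```

This sketch jumps straight to the newest checkpoint instead of evaluating every checkpoint in order. That keeps a slow evaluator from falling further behind the trainer, at the cost of skipping intermediate checkpoints.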
Setup
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
import time
import os
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1: Evaluating using tf.estimator.train_and_evaluate
In TensorFlow 1, you can evaluate a tf.estimator.Estimator using tf.estimator.train_and_evaluate.
In this example, start by defining the tf.estimator.Estimator and specifying the training and evaluation specifications:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
test_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_test},
    y=y_test.astype(np.int32),
    num_epochs=10,
    shuffle=False
)
train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
                                   steps=10,
                                   throttle_secs=0)
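Because max_steps and steps count batches rather than examples, it helps to work out how much data this configuration actually touches. The arithmetic below assumes the standard MNIST split of 60,000 training and 10,000 test images. It also assumes numpy_input_fn's default batch_size of 128 for test_input_fn, since that call does not set one.

```python
# Training: TrainSpec(max_steps=10), with batch_size=50 in train_input_fn.
train_examples = 60_000
train_batch_size = 50
max_steps = 10
examples_trained = max_steps * train_batch_size        # 500
fraction_of_epoch = examples_trained / train_examples  # 1/120 of one epoch

# Evaluation: EvalSpec(steps=10); test_input_fn uses the default batch size.
eval_batch_size = 128  # assumed numpy_input_fn default
eval_steps = 10
examples_evaluated = eval_steps * eval_batch_size      # 1280

print(examples_trained, f"{fraction_of_epoch:.2%}", examples_evaluated)  # 500 0.83% 1280
```

In other words, this example is a quick smoke test. Training stops after less than 1% of a single epoch, well before the num_epochs=10 limit set in train_input_fn is reached.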
Then, train and evaluate the model. Because this notebook runs locally, evaluation runs synchronously, alternating with training. If the estimator is used in a distributed setting, however, evaluation runs on a dedicated evaluator task. For more information, check the migration guide on distributed training.
tf1.estimator.train_and_evaluate(estimator=classifier,
                                 train_spec=train_spec,
                                 eval_spec=eval_spec)
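The local schedule described above is strictly serial: train for a while, stop to evaluate, then resume. The plain-Python sketch below models that alternation. It is a simplification, not the Estimator's implementation. The names train_then_evaluate, train_steps, evaluate, and steps_per_round are hypothetical, and in the real API evaluation is roughly tied to checkpoint saves, subject to throttle_secs.

```python
def train_then_evaluate(train_steps, evaluate, max_steps, steps_per_round):
    """Alternate training and evaluation on a single worker (simplified model).

    train_steps(n) and evaluate(step) are hypothetical callables standing in
    for a training pass of n steps and an evaluation pass.
    """
    history = []
    step = 0
    while step < max_steps:
        chunk = min(steps_per_round, max_steps - step)
        train_steps(chunk)  # The worker trains...
        step += chunk
        history.append((step, evaluate(step)))  # ...then stops to evaluate.
    return history


# Usage: record how many steps each training round ran.
rounds = []
history = train_then_evaluate(
    train_steps=rounds.append,
    evaluate=lambda step: f"metrics@{step}",
    max_steps=10,
    steps_per_round=4,
)
print(rounds)   # [4, 4, 2]
print(history)  # [(4, 'metrics@4'), (8, 'metrics@8'), (10, 'metrics@10')]
```

While evaluate runs, no training happens. A dedicated evaluator task removes exactly this cost.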
TensorFlow 2: Evaluating a Keras model
In TensorFlow 2, if you use the Keras Model.fit API for training, you can evaluate the model with tf.keras.utils.SidecarEvaluator. You can also visualize the evaluation metrics in TensorBoard, which this guide does not cover.
To demonstrate this, first define and train the model:
def create_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10)
    ])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = create_model()
model.compile(optimizer='adam',
              loss=loss,
              metrics=['accuracy'],
              steps_per_execution=10,
              run_eagerly=True)
log_dir = tempfile.mkdtemp()
model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(log_dir, 'ckpt-{epoch}'),
    save_weights_only=True)
model.fit(x=x_train,
          y=y_train,
          epochs=1,
          callbacks=[model_checkpoint])
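As a sanity check on what model.fit is doing, the arithmetic below counts the trainable parameters in create_model() and the steps in the single training epoch. It assumes Keras's default batch_size of 32 for fit, since none is passed. It also assumes ModelCheckpoint numbers epochs from 1, so 'ckpt-{epoch}' yields a single checkpoint prefix.

```python
import math

# Parameters: Flatten and Dropout have none; each Dense layer has
# (inputs * units) weights plus one bias per unit.
flat_inputs = 28 * 28                    # 784 pixels after Flatten
dense_1 = flat_inputs * 512 + 512        # 401,920
dense_2 = 512 * 10 + 10                  # 5,130
total_params = dense_1 + dense_2         # 407,050

# Steps in one epoch over 60,000 training images at the default batch size.
steps_per_epoch = math.ceil(60_000 / 32)  # 1,875

# With epochs=1, the 'ckpt-{epoch}' template is filled in once.
checkpoint_prefix = "ckpt-{epoch}".format(epoch=1)

print(total_params, steps_per_epoch, checkpoint_prefix)  # 407050 1875 ckpt-1
```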
Then, evaluate the model using tf.keras.utils.SidecarEvaluator. In real training, it's recommended to run the evaluation in a separate job so that worker resources remain free for training.
data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
data = data.batch(64)
tf.keras.utils.SidecarEvaluator(
    model=model,
    data=data,
    checkpoint_dir=log_dir,
    max_evaluations=1
).start()
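Because no steps argument is passed to SidecarEvaluator, each evaluation is expected to run until data is exhausted. It therefore helps to know how many batches data.batch(64) produces from the 10,000-image MNIST test split. The arithmetic below relies on Dataset.batch keeping the final partial batch unless drop_remainder=True is set.

```python
import math

num_test_examples = 10_000  # size of the MNIST test split
batch_size = 64             # from data.batch(64)

# Default behavior: the last, smaller batch is kept.
num_batches = math.ceil(num_test_examples / batch_size)                # 157
last_batch_size = num_test_examples - (num_batches - 1) * batch_size   # 16

# With drop_remainder=True, the 16 leftover examples would be skipped.
num_batches_dropped = num_test_examples // batch_size                  # 156

print(num_batches, last_batch_size, num_batches_dropped)  # 157 16 156
```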
Next steps
- To learn more about sidecar evaluation, read the tf.keras.utils.SidecarEvaluator API docs.
- To learn about alternating training and evaluation in Keras, read about the other built-in methods.