This guide demonstrates how to migrate from TensorFlow 1's tf.estimator.Estimator APIs to TensorFlow 2's tf.keras APIs. First, you will set up and run a basic model for training and evaluation with tf.estimator.Estimator. Then, you will perform the equivalent steps in TensorFlow 2 with the tf.keras APIs. You will also learn how to customize the training step by subclassing tf.keras.Model and using tf.GradientTape.
- In TensorFlow 1, the high-level tf.estimator.Estimator APIs let you train and evaluate a model, as well as perform inference and save your model (for serving).
- In TensorFlow 2, use the Keras APIs to perform the aforementioned tasks, such as model building, gradient application, training, evaluation, and prediction.

(For migrating model/checkpoint saving workflows to TensorFlow 2, check out the SavedModel and Checkpoint migration guides.)
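For example, exporting a trained Keras model for serving amounts to a single save call. A minimal, self-contained sketch (the tiny model and the export path /tmp/example_saved_model are illustrative, not part of this guide's workflow):

import tensorflow as tf

# Illustrative only: a tiny Keras model with a defined input shape so it
# can be exported without being trained or called first.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[2])])
# Export in the SavedModel format, which serving tools can load.
tf.saved_model.save(model, "/tmp/example_saved_model")
# The exported model can be loaded back for inference.
reloaded = tf.saved_model.load("/tmp/example_saved_model")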
Setup
Start with imports and a simple dataset:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]
TensorFlow 1: Train and evaluate with tf.estimator.Estimator
This example shows how to perform training and evaluation with tf.estimator.Estimator in TensorFlow 1.
Start by defining a few functions: an input function for the training data, an evaluation input function for the evaluation data, and a model function that tells the Estimator how the training op is defined with the features and labels:
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

def _model_fn(features, labels, mode):
  # A single linear layer produces the predictions.
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  # The training op minimizes the loss and increments the global step.
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
Instantiate your Estimator, and train the model:
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
estimator.train(_input_fn)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp4lir836r
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp4lir836r', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp4lir836r/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.1985006, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmp4lir836r/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Loss for final step: 8.682148.
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f22203fc100>
Evaluate the program with the evaluation set:
estimator.evaluate(_eval_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T03:49:31
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp4lir836r/model.ckpt-3
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.26095s
INFO:tensorflow:Finished evaluation at 2022-12-14-03:49:31
INFO:tensorflow:Saving dict for global step 3: global_step = 3, loss = 20.376455
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmpfs/tmp/tmp4lir836r/model.ckpt-3
{'loss': 20.376455, 'global_step': 3}
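Note that the _model_fn above only handles training, so Estimator.predict (the TensorFlow 1 inference path mentioned earlier) would fail with it. A hedged sketch of how the model function could be extended to support PREDICT mode (this variant is illustrative and is not run in this guide):

def _predict_model_fn(features, labels, mode):
  # Illustrative sketch: the same linear model as `_model_fn`, but it
  # returns predictions when the Estimator is called in PREDICT mode.
  logits = tf1.layers.Dense(1)(features)
  if mode == tf1.estimator.ModeKeys.PREDICT:
    return tf1.estimator.EstimatorSpec(mode, predictions=logits)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)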
TensorFlow 2: Train and evaluate with the built-in Keras methods
This example demonstrates how to perform training and evaluation with Keras Model.fit and Model.evaluate in TensorFlow 2. (You can learn more in the Training and evaluation with the built-in methods guide.)
- Start by preparing the dataset pipeline with the tf.data.Dataset APIs.
- Define a simple Keras Sequential model with one linear (tf.keras.layers.Dense) layer.
- Instantiate an Adagrad optimizer (tf.keras.optimizers.Adagrad).
- Configure the model for training by passing the optimizer variable and the mean-squared error ("mse") loss to Model.compile.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
    (eval_features, eval_labels)).batch(1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer=optimizer, loss="mse")
With that, you are ready to train the model by calling Model.fit:
model.fit(dataset)
3/3 [==============================] - 0s 5ms/step - loss: 8.5447
<keras.callbacks.History at 0x7f21097bb160>
Finally, evaluate the model with Model.evaluate:
model.evaluate(eval_dataset, return_dict=True)
3/3 [==============================] - 0s 3ms/step - loss: 26.3376
{'loss': 26.337629318237305}
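For inference, the Keras counterpart of Estimator.predict is Model.predict. When a dataset yields (features, labels) pairs, Keras uses only the features for prediction:

# Generate predictions for the evaluation examples; the labels in
# `eval_dataset` are ignored by `Model.predict`.
predictions = model.predict(eval_dataset)
print(predictions)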
TensorFlow 2: Train and evaluate with a custom training step and built-in Keras methods
In TensorFlow 2, you can also write your own custom training step function with tf.GradientTape to perform forward and backward passes, while still taking advantage of the built-in training support, such as tf.keras.callbacks.Callback and tf.distribute.Strategy. (Learn more in Customizing what happens in Model.fit and Writing custom training loops from scratch.)
In this example, start by creating a custom tf.keras.Model by subclassing tf.keras.Sequential and overriding Model.train_step. (Learn more about subclassing tf.keras.Model.) Inside that class, define a custom train_step function that, for each batch of data, performs a forward and backward pass during one training step.
class CustomModel(tf.keras.Sequential):
  """A custom sequential model that overrides `Model.train_step`."""

  def train_step(self, data):
    batch_data, labels = data
    with tf.GradientTape() as tape:
      predictions = self(batch_data, training=True)
      # Compute the loss value (the loss function is configured
      # in `Model.compile`).
      loss = self.compiled_loss(labels, predictions)
    # Compute the gradients of the parameters with respect to the loss.
    gradients = tape.gradient(loss, self.trainable_variables)
    # Perform gradient descent by updating the weights/parameters.
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    # Update the metrics (includes the metric that tracks the loss).
    self.compiled_metrics.update_state(labels, predictions)
    # Return a dict mapping metric names to the current values.
    return {m.name: m.result() for m in self.metrics}
Next, as before:

- Prepare the dataset pipeline with tf.data.Dataset.
- Define a simple model with one tf.keras.layers.Dense layer.
- Instantiate an Adagrad optimizer (tf.keras.optimizers.Adagrad).
- Configure the model for training with Model.compile, using mean-squared error ("mse") as the loss function.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
    (eval_features, eval_labels)).batch(1)
model = CustomModel([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer=optimizer, loss="mse")
Call Model.fit to train the model:
model.fit(dataset)
3/3 [==============================] - 0s 3ms/step - loss: 5.6254
<keras.callbacks.History at 0x7f2119e93cd0>
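Because the custom train_step plugs into the standard Model.fit machinery, built-in training support such as callbacks still works unchanged. A small illustration (this extra fit call is for demonstration only):

# Illustrative: a simple logging callback applied to the custom model.
log_callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: print(f"End of epoch {epoch}: {logs}"))
model.fit(dataset, callbacks=[log_callback])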
And, finally, evaluate the program with Model.evaluate:
model.evaluate(eval_dataset, return_dict=True)
3/3 [==============================] - 0s 3ms/step - loss: 15.1705
{'loss': 15.170487403869629}
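Here, Model.evaluate still uses the default test_step. If you also need to customize evaluation, you can override Model.test_step in the same way. A sketch under the same assumptions as CustomModel above (the class name is illustrative):

class CustomModelWithEval(CustomModel):
  """A sketch: also overrides `Model.test_step` to customize evaluation."""

  def test_step(self, data):
    batch_data, labels = data
    # Forward pass in inference mode; no gradient tape is needed.
    predictions = self(batch_data, training=False)
    # Update the compiled loss and any configured metrics.
    self.compiled_loss(labels, predictions)
    self.compiled_metrics.update_state(labels, predictions)
    # Return a dict mapping metric names to the current values.
    return {m.name: m.result() for m in self.metrics}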
Next steps
Additional Keras resources you may find useful:
- Guide: Training and evaluation with the built-in methods
- Guide: Customize what happens in Model.fit
- Guide: Writing a training loop from scratch
- Guide: Making new Keras layers and models via subclassing
The following guides can assist with migrating distribution strategy workflows from tf.estimator APIs: