
Migrate SessionRunHook to Keras callbacks


In TensorFlow 1, you customize training behavior by using tf.estimator.SessionRunHook with tf.estimator.Estimator. This guide demonstrates how to migrate from SessionRunHook to TensorFlow 2's custom callbacks with the tf.keras.callbacks.Callback API, which works with Keras Model.fit for training (as well as Model.evaluate and Model.predict). You will learn how to do this by implementing the same task, measuring examples per second during training, first as a SessionRunHook and then as a Callback.

Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. Examples of built-in callbacks are checkpoint saving (tf.keras.callbacks.ModelCheckpoint) and TensorBoard summary writing (tf.keras.callbacks.TensorBoard). You can learn more about callbacks in the tf.keras.callbacks.Callback API docs, as well as the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides.
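To make "called at different points" concrete, the following is a minimal sketch of how a training loop might dispatch to callback methods. The toy_fit function and RecordingCallback class are hypothetical names used only for illustration, and the loop does no training; it is not how Keras is implemented internally. Only the method names (on_train_begin, on_epoch_begin, on_train_batch_begin, on_train_batch_end, on_epoch_end, on_train_end) match the tf.keras.callbacks.Callback API.

```python
class RecordingCallback:
  """Hypothetical callback that records the order in which its methods run."""

  def __init__(self):
    self.events = []

  def on_train_begin(self, logs=None):
    self.events.append('train_begin')

  def on_epoch_begin(self, epoch, logs=None):
    self.events.append(f'epoch_begin:{epoch}')

  def on_train_batch_begin(self, batch, logs=None):
    self.events.append(f'batch_begin:{batch}')

  def on_train_batch_end(self, batch, logs=None):
    self.events.append(f'batch_end:{batch}')

  def on_epoch_end(self, epoch, logs=None):
    self.events.append(f'epoch_end:{epoch}')

  def on_train_end(self, logs=None):
    self.events.append('train_end')


def toy_fit(num_epochs, steps_per_epoch, callbacks):
  """Hypothetical stand-in for a training loop; performs no real training."""
  for cb in callbacks:
    cb.on_train_begin()
  for epoch in range(num_epochs):
    for cb in callbacks:
      cb.on_epoch_begin(epoch)
    # The batch index restarts at 0 in every epoch.
    for batch in range(steps_per_epoch):
      for cb in callbacks:
        cb.on_train_batch_begin(batch)
      # A real loop would run one optimization step here.
      for cb in callbacks:
        cb.on_train_batch_end(batch, logs={})
    for cb in callbacks:
      cb.on_epoch_end(epoch)
  for cb in callbacks:
    cb.on_train_end()


recorder = RecordingCallback()
toy_fit(num_epochs=2, steps_per_epoch=2, callbacks=[recorder])
print(recorder.events)
```

Note that the batch argument counts from zero within each epoch, so a callback that needs a step count spanning the whole training run has to keep its own counter.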

Setup

Start with imports and a simple dataset for demonstration purposes:

import tensorflow as tf
import tensorflow.compat.v1 as tf1

import time
from datetime import datetime

features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]

TensorFlow 1: Create a custom SessionRunHook with tf.estimator APIs

The following TensorFlow 1 examples show how to set up a custom SessionRunHook that measures examples per second during training. After creating the hook (LoggerHook), pass it to the hooks parameter of tf.estimator.Estimator.train.

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (features, labels)).batch(1).repeat(100)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

class LoggerHook(tf1.train.SessionRunHook):
  """Logs examples per second during training."""

  def begin(self):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      # One step equals one example here because _input_fn uses batch(1).
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

# Begin training.
estimator.train(_input_fn, hooks=[LoggerHook()])
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpd3kkest0
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpd3kkest0', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpd3kkest0/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
Time: 2021-09-22 20:03:54.632194 , Step #: 0 , Examples per second: 6.10470928029035
INFO:tensorflow:loss = 0.4186823, step = 0
Time: 2021-09-22 20:03:54.667973 , Step #: 10 , Examples per second: 279.48611333226717
Time: 2021-09-22 20:03:54.677735 , Step #: 20 , Examples per second: 1024.350119669809
Time: 2021-09-22 20:03:54.687424 , Step #: 30 , Examples per second: 1032.0883880016734
Time: 2021-09-22 20:03:54.697468 , Step #: 40 , Examples per second: 995.4914200270572
Time: 2021-09-22 20:03:54.707063 , Step #: 50 , Examples per second: 1042.2443654797107
Time: 2021-09-22 20:03:54.717058 , Step #: 60 , Examples per second: 1000.525750817013
Time: 2021-09-22 20:03:54.726516 , Step #: 70 , Examples per second: 1057.4053345434377
Time: 2021-09-22 20:03:54.736013 , Step #: 80 , Examples per second: 1052.8135746379176
Time: 2021-09-22 20:03:54.745926 , Step #: 90 , Examples per second: 1008.7796430804752
INFO:tensorflow:global_step/sec: 800.052
Time: 2021-09-22 20:03:54.758401 , Step #: 100 , Examples per second: 801.6329651007225
INFO:tensorflow:loss = 6.092345e-05, step = 100 (0.126 sec)
Time: 2021-09-22 20:03:54.769678 , Step #: 110 , Examples per second: 886.7637793610859
Time: 2021-09-22 20:03:54.779611 , Step #: 120 , Examples per second: 1006.7940470475276
Time: 2021-09-22 20:03:54.789567 , Step #: 130 , Examples per second: 1004.407193658852
Time: 2021-09-22 20:03:54.799209 , Step #: 140 , Examples per second: 1037.064583127287
Time: 2021-09-22 20:03:54.809019 , Step #: 150 , Examples per second: 1019.3710202692849
Time: 2021-09-22 20:03:54.818400 , Step #: 160 , Examples per second: 1065.951001321541
Time: 2021-09-22 20:03:54.828427 , Step #: 170 , Examples per second: 997.3614876111666
Time: 2021-09-22 20:03:54.837958 , Step #: 180 , Examples per second: 1049.126791565572
Time: 2021-09-22 20:03:54.847210 , Step #: 190 , Examples per second: 1080.866898595542
INFO:tensorflow:global_step/sec: 989.503
Time: 2021-09-22 20:03:54.859304 , Step #: 200 , Examples per second: 826.887469442473
INFO:tensorflow:loss = 0.00023960586, step = 200 (0.101 sec)
Time: 2021-09-22 20:03:54.871152 , Step #: 210 , Examples per second: 843.9922730199613
Time: 2021-09-22 20:03:54.881449 , Step #: 220 , Examples per second: 971.1509875199704
Time: 2021-09-22 20:03:54.891077 , Step #: 230 , Examples per second: 1038.6568272992918
Time: 2021-09-22 20:03:54.900433 , Step #: 240 , Examples per second: 1068.803098641796
Time: 2021-09-22 20:03:54.909864 , Step #: 250 , Examples per second: 1060.4262634945517
Time: 2021-09-22 20:03:54.919740 , Step #: 260 , Examples per second: 1012.578822847762
Time: 2021-09-22 20:03:54.929341 , Step #: 270 , Examples per second: 1041.4679810294738
Time: 2021-09-22 20:03:54.939518 , Step #: 280 , Examples per second: 983.2397205682404
Time: 2021-09-22 20:03:54.950539 , Step #: 290 , Examples per second: 906.9157585192874
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 300...
INFO:tensorflow:Saving checkpoints for 300 into /tmp/tmpd3kkest0/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 300...
INFO:tensorflow:Loss for final step: 9.510103e-05.
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f08d9f857d0>
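The hook computes log_frequency / duration, which equals examples per second only because the input pipeline uses batch(1). It also divides by a full log_frequency at step 0, when just one step has run, so the first reading is not a true rate. The following is a minimal sketch of a rate calculation that accounts for both points. ThroughputMeter, its batch_size parameter, and the injectable clock are hypothetical and not part of any TensorFlow API; the clock is replaceable only so the arithmetic can be checked deterministically.

```python
import time


class ThroughputMeter:
  """Hypothetical helper that reports examples per second between reports."""

  def __init__(self, batch_size, log_frequency=10, clock=time.time):
    self.batch_size = batch_size
    self.log_frequency = log_frequency
    self._clock = clock
    self._steps_since_report = 0
    self._last_report_time = clock()

  def step(self):
    """Call once per completed batch. Returns a rate, or None between reports."""
    self._steps_since_report += 1
    if self._steps_since_report < self.log_frequency:
      return None
    now = self._clock()
    duration = now - self._last_report_time
    # Count the examples actually processed since the last report.
    examples = self._steps_since_report * self.batch_size
    self._steps_since_report = 0
    self._last_report_time = now
    return examples / duration if duration > 0 else float('inf')


# Fake clock: 0.0 at construction, 2.0 when the 10th step completes.
ticks = iter([0.0, 2.0])
meter = ThroughputMeter(batch_size=4, log_frequency=10,
                        clock=lambda: next(ticks))
rates = [meter.step() for _ in range(10)]
print(rates[-1])  # 10 steps * 4 examples / 2.0 seconds = 20.0
```

A hook or callback could hold one such object and call step() from after_run or on_train_batch_end, printing whenever the return value is not None.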

TensorFlow 2: Create a custom Keras callback for Model.fit

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure a custom tf.keras.callbacks.Callback, which you then pass to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the Writing your own callbacks guide.)

In the example below, you will write a custom tf.keras.callbacks.Callback that measures examples per second, which should be comparable to the measurements from the previous SessionRunHook example.

class CustomCallback(tf.keras.callbacks.Callback):

  def on_train_begin(self, logs=None):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10

  def on_train_batch_begin(self, batch, logs=None):
    self._step += 1

  def on_train_batch_end(self, batch, logs=None):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)

callback = CustomCallback()

dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).batch(1).repeat(100)

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer, "mse")

# Begin training.
result = model.fit(dataset, callbacks=[callback], verbose=0)
# Show the training metrics recorded by Model.fit.
result.history
Time: 2021-09-22 20:03:55.292902 , Step #: 0 , Examples per second: 41.7177640740004
Time: 2021-09-22 20:03:55.307630 , Step #: 10 , Examples per second: 678.931658519214
Time: 2021-09-22 20:03:55.321877 , Step #: 20 , Examples per second: 701.916827043762
Time: 2021-09-22 20:03:55.335955 , Step #: 30 , Examples per second: 710.2970364098222
Time: 2021-09-22 20:03:55.349107 , Step #: 40 , Examples per second: 760.3473342639088
Time: 2021-09-22 20:03:55.363413 , Step #: 50 , Examples per second: 699.0273657544749
Time: 2021-09-22 20:03:55.376974 , Step #: 60 , Examples per second: 737.395218002813
Time: 2021-09-22 20:03:55.390098 , Step #: 70 , Examples per second: 761.9357651504142
Time: 2021-09-22 20:03:55.403143 , Step #: 80 , Examples per second: 766.5869795664729
Time: 2021-09-22 20:03:55.416655 , Step #: 90 , Examples per second: 740.1408177310346
Time: 2021-09-22 20:03:55.430573 , Step #: 100 , Examples per second: 718.4118664679787
Time: 2021-09-22 20:03:55.444679 , Step #: 110 , Examples per second: 708.9763353617309
Time: 2021-09-22 20:03:55.458107 , Step #: 120 , Examples per second: 744.7140498215585
Time: 2021-09-22 20:03:55.472363 , Step #: 130 , Examples per second: 701.4355475282628
Time: 2021-09-22 20:03:55.485659 , Step #: 140 , Examples per second: 752.1661316643653
Time: 2021-09-22 20:03:55.499408 , Step #: 150 , Examples per second: 727.1804296190988
Time: 2021-09-22 20:03:55.513545 , Step #: 160 , Examples per second: 707.4814877287678
Time: 2021-09-22 20:03:55.526756 , Step #: 170 , Examples per second: 756.9033096329448
Time: 2021-09-22 20:03:55.540358 , Step #: 180 , Examples per second: 735.1205832865957
Time: 2021-09-22 20:03:55.554744 , Step #: 190 , Examples per second: 695.2038719087715
Time: 2021-09-22 20:03:55.568133 , Step #: 200 , Examples per second: 746.8357044924413
Time: 2021-09-22 20:03:55.581976 , Step #: 210 , Examples per second: 722.3961006527618
Time: 2021-09-22 20:03:55.595447 , Step #: 220 , Examples per second: 742.3678295191065
Time: 2021-09-22 20:03:55.609489 , Step #: 230 , Examples per second: 712.1301232639479
Time: 2021-09-22 20:03:55.623511 , Step #: 240 , Examples per second: 713.1593355210583
Time: 2021-09-22 20:03:55.637077 , Step #: 250 , Examples per second: 737.1878515185601
Time: 2021-09-22 20:03:55.650548 , Step #: 260 , Examples per second: 742.2627285115118
Time: 2021-09-22 20:03:55.664177 , Step #: 270 , Examples per second: 733.7573912739232
Time: 2021-09-22 20:03:55.677369 , Step #: 280 , Examples per second: 758.0524127959516
Time: 2021-09-22 20:03:55.690806 , Step #: 290 , Examples per second: 744.1987224982257
{'loss': [1.5562355518341064]}
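The same pattern carries over to Model.evaluate, which invokes the on_test_* methods of a callback instead of the on_train_* ones. The following is a minimal sketch of an evaluation timer written against those method names. EvalTimer is a hypothetical name, and the class subclasses object here only so the logic can be exercised by calling the methods by hand; to use it with Keras, the base class would need to be tf.keras.callbacks.Callback.

```python
import time


# Assumption: replace `object` with tf.keras.callbacks.Callback for real use.
class EvalTimer(object):
  """Hypothetical callback that times an evaluation run."""

  def on_test_begin(self, logs=None):
    self._batches = 0
    self._start_time = time.time()

  def on_test_batch_end(self, batch, logs=None):
    self._batches += 1

  def on_test_end(self, logs=None):
    self.elapsed = time.time() - self._start_time
    print('Evaluated', self._batches, 'batches in', self.elapsed, 'seconds')


# Drive the methods by hand, in the order an evaluation loop would call them.
timer = EvalTimer()
timer.on_test_begin()
for batch in range(3):
  timer.on_test_batch_end(batch)
timer.on_test_end()
```

With the Keras base class in place, an instance could be passed to the callbacks parameter of Model.evaluate, for example on a dataset built from the eval_features and eval_labels defined in the Setup section.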

Next steps

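Learn more about callbacks in the tf.keras.callbacks.Callback API docs, the Writing your own callbacks guide, and the Using callbacks section of the Training and evaluation with the built-in methods guide.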
