
Migrate early stopping


This notebook demonstrates how you can set up model training with early stopping, first, in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then, in TensorFlow 2 with Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training if, for example, the validation loss reaches a certain threshold.

In TensorFlow 2, there are three ways to implement early stopping:

- Use a built-in Keras callback—tf.keras.callbacks.EarlyStopping—and pass it to Model.fit.
- Define a custom callback and pass it to Keras Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).

Setup

import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds

TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

Start by defining functions for MNIST dataset loading and preprocessing, and model definition to be used with tf.estimator.Estimator:

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
    name='mnist',
    split='train',
    shuffle_files=True,
    as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
    name='mnist',
    split='test',
    shuffle_files=True,
    as_supervised=True)
  ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function that takes no arguments to the hook's should_stop_fn parameter. The training stops once should_stop_fn returns True.

The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpfmt3ke80
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpfmt3ke80', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpfmt3ke80/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.3043585, step = 0
INFO:tensorflow:global_step/sec: 144.825
INFO:tensorflow:loss = 1.3497735, step = 100 (0.693 sec)
INFO:tensorflow:global_step/sec: 154.727
INFO:tensorflow:loss = 0.83353037, step = 200 (0.646 sec)
INFO:tensorflow:global_step/sec: 160.674
INFO:tensorflow:loss = 0.69762164, step = 300 (0.622 sec)
INFO:tensorflow:global_step/sec: 151.078
INFO:tensorflow:loss = 0.6154139, step = 400 (0.662 sec)
INFO:tensorflow:global_step/sec: 373.93
INFO:tensorflow:loss = 0.50315976, step = 500 (0.267 sec)
INFO:tensorflow:global_step/sec: 517.621
INFO:tensorflow:loss = 0.4667952, step = 600 (0.193 sec)
INFO:tensorflow:global_step/sec: 556.511
INFO:tensorflow:loss = 0.38532913, step = 700 (0.180 sec)
INFO:tensorflow:global_step/sec: 557.793
INFO:tensorflow:loss = 0.51158535, step = 800 (0.179 sec)
INFO:tensorflow:global_step/sec: 553.041
INFO:tensorflow:loss = 0.402203, step = 900 (0.181 sec)
INFO:tensorflow:global_step/sec: 517.58
INFO:tensorflow:loss = 0.44644293, step = 1000 (0.193 sec)
INFO:tensorflow:global_step/sec: 544.171
INFO:tensorflow:loss = 0.43521973, step = 1100 (0.184 sec)
INFO:tensorflow:global_step/sec: 546.196
INFO:tensorflow:loss = 0.38868037, step = 1200 (0.183 sec)
INFO:tensorflow:global_step/sec: 545.356
INFO:tensorflow:loss = 0.49474105, step = 1300 (0.183 sec)
INFO:tensorflow:global_step/sec: 533.842
INFO:tensorflow:loss = 0.2977963, step = 1400 (0.188 sec)
INFO:tensorflow:global_step/sec: 381.327
INFO:tensorflow:loss = 0.29854277, step = 1500 (0.262 sec)
INFO:tensorflow:global_step/sec: 394.517
INFO:tensorflow:loss = 0.4113207, step = 1600 (0.253 sec)
INFO:tensorflow:global_step/sec: 405.687
INFO:tensorflow:loss = 0.38139153, step = 1700 (0.247 sec)
INFO:tensorflow:global_step/sec: 569.242
INFO:tensorflow:loss = 0.32674897, step = 1800 (0.176 sec)
INFO:tensorflow:global_step/sec: 500.692
INFO:tensorflow:loss = 0.5069632, step = 1900 (0.199 sec)
INFO:tensorflow:global_step/sec: 499.402
INFO:tensorflow:loss = 0.21091068, step = 2000 (0.200 sec)
INFO:tensorflow:global_step/sec: 551.063
INFO:tensorflow:loss = 0.29082954, step = 2100 (0.181 sec)
INFO:tensorflow:global_step/sec: 550.014
INFO:tensorflow:loss = 0.3142205, step = 2200 (0.182 sec)
INFO:tensorflow:global_step/sec: 554.924
INFO:tensorflow:loss = 0.354444, step = 2300 (0.181 sec)
INFO:tensorflow:global_step/sec: 515.242
INFO:tensorflow:loss = 0.26232147, step = 2400 (0.194 sec)
INFO:tensorflow:global_step/sec: 560.106
INFO:tensorflow:loss = 0.2223329, step = 2500 (0.179 sec)
INFO:tensorflow:global_step/sec: 491.955
INFO:tensorflow:loss = 0.16123277, step = 2600 (0.203 sec)
INFO:tensorflow:global_step/sec: 528.073
INFO:tensorflow:loss = 0.31785184, step = 2700 (0.189 sec)
INFO:tensorflow:global_step/sec: 519.496
INFO:tensorflow:loss = 0.48224646, step = 2800 (0.193 sec)
INFO:tensorflow:global_step/sec: 482.332
INFO:tensorflow:loss = 0.2396833, step = 2900 (0.207 sec)
INFO:tensorflow:global_step/sec: 509.391
INFO:tensorflow:loss = 0.3452201, step = 3000 (0.196 sec)
INFO:tensorflow:global_step/sec: 503.545
INFO:tensorflow:loss = 0.20264056, step = 3100 (0.199 sec)
INFO:tensorflow:global_step/sec: 520.823
INFO:tensorflow:loss = 0.40972877, step = 3200 (0.192 sec)
INFO:tensorflow:global_step/sec: 520.413
INFO:tensorflow:loss = 0.3450895, step = 3300 (0.192 sec)
INFO:tensorflow:global_step/sec: 557.182
INFO:tensorflow:loss = 0.26727343, step = 3400 (0.180 sec)
INFO:tensorflow:global_step/sec: 566.333
INFO:tensorflow:loss = 0.2099815, step = 3500 (0.176 sec)
INFO:tensorflow:global_step/sec: 569.768
INFO:tensorflow:loss = 0.23265842, step = 3600 (0.175 sec)
INFO:tensorflow:global_step/sec: 552.542
INFO:tensorflow:loss = 0.28018022, step = 3700 (0.181 sec)
INFO:tensorflow:global_step/sec: 521.827
INFO:tensorflow:loss = 0.37057906, step = 3800 (0.192 sec)
INFO:tensorflow:global_step/sec: 542.979
INFO:tensorflow:loss = 0.2255877, step = 3900 (0.184 sec)
INFO:tensorflow:global_step/sec: 550.828
INFO:tensorflow:loss = 0.30667296, step = 4000 (0.182 sec)
INFO:tensorflow:global_step/sec: 544.601
INFO:tensorflow:loss = 0.21357399, step = 4100 (0.183 sec)
INFO:tensorflow:global_step/sec: 555.71
INFO:tensorflow:loss = 0.2577279, step = 4200 (0.180 sec)
INFO:tensorflow:global_step/sec: 446.746
INFO:tensorflow:loss = 0.29612792, step = 4300 (0.223 sec)
INFO:tensorflow:global_step/sec: 546.035
INFO:tensorflow:loss = 0.31375682, step = 4400 (0.183 sec)
INFO:tensorflow:global_step/sec: 510.825
INFO:tensorflow:loss = 0.28181487, step = 4500 (0.196 sec)
INFO:tensorflow:global_step/sec: 557.316
INFO:tensorflow:loss = 0.31139958, step = 4600 (0.180 sec)
INFO:tensorflow:global_step/sec: 505.992
INFO:tensorflow:loss = 0.15156099, step = 4700 (0.197 sec)
INFO:tensorflow:global_step/sec: 489.912
INFO:tensorflow:loss = 0.2878572, step = 4800 (0.204 sec)
INFO:tensorflow:global_step/sec: 555.884
INFO:tensorflow:loss = 0.38200092, step = 4900 (0.180 sec)
INFO:tensorflow:global_step/sec: 521.052
INFO:tensorflow:loss = 0.27179155, step = 5000 (0.192 sec)
INFO:tensorflow:global_step/sec: 541.078
INFO:tensorflow:loss = 0.33555904, step = 5100 (0.185 sec)
INFO:tensorflow:global_step/sec: 516.851
INFO:tensorflow:loss = 0.21896324, step = 5200 (0.194 sec)
INFO:tensorflow:global_step/sec: 559.338
INFO:tensorflow:loss = 0.25059485, step = 5300 (0.179 sec)
INFO:tensorflow:global_step/sec: 464.363
INFO:tensorflow:loss = 0.15162958, step = 5400 (0.215 sec)
INFO:tensorflow:global_step/sec: 563.826
INFO:tensorflow:loss = 0.2413958, step = 5500 (0.177 sec)
INFO:tensorflow:global_step/sec: 561.753
INFO:tensorflow:loss = 0.18371066, step = 5600 (0.179 sec)
INFO:tensorflow:global_step/sec: 502.932
INFO:tensorflow:loss = 0.17668869, step = 5700 (0.199 sec)
INFO:tensorflow:Requesting early stopping at global step 5768
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5769...
INFO:tensorflow:Saving checkpoints for 5769 into /tmpfs/tmp/tmpfmt3ke80/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5769...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2023-01-21T02:21:22
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpfmt3ke80/model.ckpt-5769
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 1.06422s
INFO:tensorflow:Finished evaluation at 2023-01-21-02:21:23
INFO:tensorflow:Saving dict for global step 5769: global_step = 5769, loss = 0.24380085
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5769: /tmpfs/tmp/tmpfmt3ke80/model.ckpt-5769
INFO:tensorflow:Loss for final step: 0.22827312.
({'loss': 0.24380085, 'global_step': 5769}, [])
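The contract of should_stop_fn—a zero-argument callable that is polled periodically, with True ending training—can be sketched without TensorFlow. The driver loop below is a hypothetical stand-in for the Estimator's hook machinery, shown only to make the contract concrete:

```python
# Minimal stand-in for the early stopping hook contract: between steps,
# poll a zero-argument predicate and stop training once it returns True.
def run_training(train_one_step, should_stop_fn, max_steps=1_000_000):
    steps = 0
    while steps < max_steps:
        if should_stop_fn():
            break
        train_one_step()
        steps += 1
    return steps

# Example predicate: stop once 500 steps have been taken.
counter = {"steps": 0}

def take_step():
    counter["steps"] += 1

steps_run = run_training(take_step, lambda: counter["steps"] >= 500)
# steps_run == 500
```

In the Estimator example above, the predicate is time-based rather than step-based, and the hook re-evaluates it on the schedule given by run_every_secs or run_every_steps.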

TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate), you can configure early stopping by passing a built-in callback—tf.keras.callbacks.EarlyStopping—to the callbacks parameter of Model.fit.

The EarlyStopping callback monitors a user-specified metric and ends training when it stops improving. (Check the Training and evaluation with the built-in methods guide or the API docs for more information.)

Below is an example of an early stopping callback that monitors the loss and stops training after three consecutive epochs (patience) show no improvement:

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Training stops well before the configured 100 epochs.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 3s 5ms/step - loss: 0.2276 - sparse_categorical_accuracy: 0.9324 - val_loss: 0.1130 - val_sparse_categorical_accuracy: 0.9663
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0958 - sparse_categorical_accuracy: 0.9711 - val_loss: 0.1038 - val_sparse_categorical_accuracy: 0.9689
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0658 - sparse_categorical_accuracy: 0.9800 - val_loss: 0.1069 - val_sparse_categorical_accuracy: 0.9677
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0486 - sparse_categorical_accuracy: 0.9847 - val_loss: 0.0962 - val_sparse_categorical_accuracy: 0.9730
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0393 - sparse_categorical_accuracy: 0.9873 - val_loss: 0.1036 - val_sparse_categorical_accuracy: 0.9726
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0378 - sparse_categorical_accuracy: 0.9872 - val_loss: 0.1097 - val_sparse_categorical_accuracy: 0.9736
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9891 - val_loss: 0.1141 - val_sparse_categorical_accuracy: 0.9726
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0300 - sparse_categorical_accuracy: 0.9904 - val_loss: 0.1099 - val_sparse_categorical_accuracy: 0.9756
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0216 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1128 - val_sparse_categorical_accuracy: 0.9754
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0234 - sparse_categorical_accuracy: 0.9926 - val_loss: 0.1548 - val_sparse_categorical_accuracy: 0.9711
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0236 - sparse_categorical_accuracy: 0.9921 - val_loss: 0.1405 - val_sparse_categorical_accuracy: 0.9720
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0235 - sparse_categorical_accuracy: 0.9921 - val_loss: 0.1290 - val_sparse_categorical_accuracy: 0.9752
12
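The patience rule that EarlyStopping applies when monitoring a loss can be sketched in plain Python. The helper below is hypothetical (not part of the Keras API) and assumes the callback's default min_delta=0, so only a strict decrease counts as improvement:

```python
def epochs_until_early_stop(losses, patience):
    """Return how many epochs run before `patience` consecutive
    non-improving epochs trigger a stop (or all epochs if none do)."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:       # strict improvement resets the counter
            best = loss
            wait = 0
        else:
            wait += 1
        if wait >= patience:  # stop at the end of this epoch
            return epoch
    return len(losses)

# The training losses from the run above: the best value occurs at
# epoch 9, and epochs 10-12 all fail to improve on it, so with
# patience=3 training stops after 12 epochs.
losses = [0.2276, 0.0958, 0.0658, 0.0486, 0.0393, 0.0378,
          0.0333, 0.0300, 0.0216, 0.0234, 0.0236, 0.0235]
```

Applying the helper to these losses with patience=3 reproduces the 12 epochs seen in the output above.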

TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback, which can also be passed to the callbacks parameter of Model.fit (or Model.evaluate).

In this example, the training process is stopped once self.model.stop_training is set to True:

class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs):
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs):
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True

# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0216 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1371 - val_sparse_categorical_accuracy: 0.9749
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0207 - sparse_categorical_accuracy: 0.9934 - val_loss: 0.1495 - val_sparse_categorical_accuracy: 0.9738
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0213 - sparse_categorical_accuracy: 0.9929 - val_loss: 0.1516 - val_sparse_categorical_accuracy: 0.9758
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9955 - val_loss: 0.1699 - val_sparse_categorical_accuracy: 0.9743
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0177 - sparse_categorical_accuracy: 0.9942 - val_loss: 0.1528 - val_sparse_categorical_accuracy: 0.9767
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0183 - sparse_categorical_accuracy: 0.9944 - val_loss: 0.1708 - val_sparse_categorical_accuracy: 0.9744
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0164 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.1825 - val_sparse_categorical_accuracy: 0.9760
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0130 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.1519 - val_sparse_categorical_accuracy: 0.9799
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9949 - val_loss: 0.1943 - val_sparse_categorical_accuracy: 0.9755
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0184 - sparse_categorical_accuracy: 0.9947 - val_loss: 0.2014 - val_sparse_categorical_accuracy: 0.9763
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0158 - sparse_categorical_accuracy: 0.9953 - val_loss: 0.1790 - val_sparse_categorical_accuracy: 0.9783
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0157 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1899 - val_sparse_categorical_accuracy: 0.9766
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0130 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2224 - val_sparse_categorical_accuracy: 0.9740
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0173 - sparse_categorical_accuracy: 0.9955 - val_loss: 0.2038 - val_sparse_categorical_accuracy: 0.9754
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0166 - sparse_categorical_accuracy: 0.9955 - val_loss: 0.2048 - val_sparse_categorical_accuracy: 0.9761
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0086 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.2055 - val_sparse_categorical_accuracy: 0.9756
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2759 - val_sparse_categorical_accuracy: 0.9735
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2361 - val_sparse_categorical_accuracy: 0.9763
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0134 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2329 - val_sparse_categorical_accuracy: 0.9759
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.2216 - val_sparse_categorical_accuracy: 0.9780
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2341 - val_sparse_categorical_accuracy: 0.9775
Epoch 22/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0111 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2499 - val_sparse_categorical_accuracy: 0.9785
Epoch 23/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0124 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2589 - val_sparse_categorical_accuracy: 0.9761
Epoch 24/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0092 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2432 - val_sparse_categorical_accuracy: 0.9761
24

TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, you can implement early stopping in a custom training loop if you're not training and evaluating with the built-in Keras methods.

Start by using Keras APIs to define another simple model, an optimizer, a loss function, and metrics:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()

Define the parameter update functions with tf.GradientTape and the @tf.function decorator for a speedup:

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
      logits = model(x, training=True)
      loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)

Next, write a custom training loop where you can implement your early stopping rule manually.

The example below shows how to stop training when the validation loss doesn't improve for a set number of epochs (the patience):

epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
      loss_value = train_step(x_batch_train, y_batch_train)
      if step % 200 == 0:
        print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
        print("Seen so far: %s samples" % ((step + 1) * 128))        
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    for x_batch_val, y_batch_val in ds_test:
      test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over a certain number of epochs.
    wait += 1
    if val_loss < best:
      best = val_loss
      wait = 0
    if wait >= patience:
      break
Start of epoch 0
Training loss at step 0: 2.3311
Seen so far: 128 samples
Training loss at step 200: 0.2537
Seen so far: 25728 samples
Training loss at step 400: 0.1951
Seen so far: 51328 samples
Training acc over epoch: 0.9314
Validation acc: 0.9633
Time taken: 2.06s

Start of epoch 1
Training loss at step 0: 0.0846
Seen so far: 128 samples
Training loss at step 200: 0.1519
Seen so far: 25728 samples
Training loss at step 400: 0.1162
Seen so far: 51328 samples
Training acc over epoch: 0.9691
Validation acc: 0.9686
Time taken: 1.06s

Start of epoch 2
Training loss at step 0: 0.0470
Seen so far: 128 samples
Training loss at step 200: 0.0991
Seen so far: 25728 samples
Training loss at step 400: 0.0771
Seen so far: 51328 samples
Training acc over epoch: 0.9790
Validation acc: 0.9660
Time taken: 1.04s

Start of epoch 3
Training loss at step 0: 0.0474
Seen so far: 128 samples
Training loss at step 200: 0.0519
Seen so far: 25728 samples
Training loss at step 400: 0.0462
Seen so far: 51328 samples
Training acc over epoch: 0.9829
Validation acc: 0.9638
Time taken: 1.03s

Start of epoch 4
Training loss at step 0: 0.0320
Seen so far: 128 samples
Training loss at step 200: 0.0360
Seen so far: 25728 samples
Training loss at step 400: 0.0182
Seen so far: 51328 samples
Training acc over epoch: 0.9853
Validation acc: 0.9657
Time taken: 1.03s

Start of epoch 5
Training loss at step 0: 0.0278
Seen so far: 128 samples
Training loss at step 200: 0.0686
Seen so far: 25728 samples
Training loss at step 400: 0.0664
Seen so far: 51328 samples
Training acc over epoch: 0.9868
Validation acc: 0.9659
Time taken: 1.04s

Start of epoch 6
Training loss at step 0: 0.0167
Seen so far: 128 samples
Training loss at step 200: 0.0385
Seen so far: 25728 samples
Training loss at step 400: 0.0482
Seen so far: 51328 samples
Training acc over epoch: 0.9887
Validation acc: 0.9737
Time taken: 0.98s

Start of epoch 7
Training loss at step 0: 0.0147
Seen so far: 128 samples
Training loss at step 200: 0.0165
Seen so far: 25728 samples
Training loss at step 400: 0.0097
Seen so far: 51328 samples
Training acc over epoch: 0.9899
Validation acc: 0.9731
Time taken: 1.01s

Start of epoch 8
Training loss at step 0: 0.0041
Seen so far: 128 samples
Training loss at step 200: 0.0058
Seen so far: 25728 samples
Training loss at step 400: 0.0252
Seen so far: 51328 samples
Training acc over epoch: 0.9905
Validation acc: 0.9679
Time taken: 1.03s

Start of epoch 9
Training loss at step 0: 0.0100
Seen so far: 128 samples
Training loss at step 200: 0.0199
Seen so far: 25728 samples
Training loss at step 400: 0.0175
Seen so far: 51328 samples
Training acc over epoch: 0.9923
Validation acc: 0.9733
Time taken: 1.00s

Start of epoch 10
Training loss at step 0: 0.0068
Seen so far: 128 samples
Training loss at step 200: 0.0332
Seen so far: 25728 samples
Training loss at step 400: 0.0410
Seen so far: 51328 samples
Training acc over epoch: 0.9911
Validation acc: 0.9739
Time taken: 1.07s

Start of epoch 11
Training loss at step 0: 0.0028
Seen so far: 128 samples
Training loss at step 200: 0.0451
Seen so far: 25728 samples
Training loss at step 400: 0.0141
Seen so far: 51328 samples
Training acc over epoch: 0.9925
Validation acc: 0.9734
Time taken: 0.99s

Start of epoch 12
Training loss at step 0: 0.0010
Seen so far: 128 samples
Training loss at step 200: 0.0115
Seen so far: 25728 samples
Training loss at step 400: 0.0929
Seen so far: 51328 samples
Training acc over epoch: 0.9933
Validation acc: 0.9727
Time taken: 0.99s

Start of epoch 13
Training loss at step 0: 0.0189
Seen so far: 128 samples
Training loss at step 200: 0.0242
Seen so far: 25728 samples
Training loss at step 400: 0.0257
Seen so far: 51328 samples
Training acc over epoch: 0.9938
Validation acc: 0.9720
Time taken: 1.07s

Start of epoch 14
Training loss at step 0: 0.0009
Seen so far: 128 samples
Training loss at step 200: 0.0633
Seen so far: 25728 samples
Training loss at step 400: 0.0052
Seen so far: 51328 samples
Training acc over epoch: 0.9936
Validation acc: 0.9727
Time taken: 1.00s

Start of epoch 15
Training loss at step 0: 0.0097
Seen so far: 128 samples
Training loss at step 200: 0.0048
Seen so far: 25728 samples
Training loss at step 400: 0.0323
Seen so far: 51328 samples
Training acc over epoch: 0.9940
Validation acc: 0.9730
Time taken: 1.01s

Start of epoch 16
Training loss at step 0: 0.0166
Seen so far: 128 samples
Training loss at step 200: 0.0132
Seen so far: 25728 samples
Training loss at step 400: 0.0085
Seen so far: 51328 samples
Training acc over epoch: 0.9942
Validation acc: 0.9750
Time taken: 1.00s

Start of epoch 17
Training loss at step 0: 0.0008
Seen so far: 128 samples
Training loss at step 200: 0.0027
Seen so far: 25728 samples
Training loss at step 400: 0.0224
Seen so far: 51328 samples
Training acc over epoch: 0.9945
Validation acc: 0.9748
Time taken: 1.03s

Start of epoch 18
Training loss at step 0: 0.0231
Seen so far: 128 samples
Training loss at step 200: 0.0143
Seen so far: 25728 samples
Training loss at step 400: 0.0177
Seen so far: 51328 samples
Training acc over epoch: 0.9951
Validation acc: 0.9765
Time taken: 1.05s

Start of epoch 19
Training loss at step 0: 0.0026
Seen so far: 128 samples
Training loss at step 200: 0.0316
Seen so far: 25728 samples
Training loss at step 400: 0.1067
Seen so far: 51328 samples
Training acc over epoch: 0.9944
Validation acc: 0.9760
Time taken: 1.02s

Start of epoch 20
Training loss at step 0: 0.0001
Seen so far: 128 samples
Training loss at step 200: 0.0022
Seen so far: 25728 samples
Training loss at step 400: 0.0137
Seen so far: 51328 samples
Training acc over epoch: 0.9948
Validation acc: 0.9720
Time taken: 1.04s

Start of epoch 21
Training loss at step 0: 0.0018
Seen so far: 128 samples
Training loss at step 200: 0.0508
Seen so far: 25728 samples
Training loss at step 400: 0.0885
Seen so far: 51328 samples
Training acc over epoch: 0.9946
Validation acc: 0.9744
Time taken: 1.06s

Start of epoch 22
Training loss at step 0: 0.0559
Seen so far: 128 samples
Training loss at step 200: 0.0095
Seen so far: 25728 samples
Training loss at step 400: 0.0086
Seen so far: 51328 samples
Training acc over epoch: 0.9954
Validation acc: 0.9741
Time taken: 1.07s

Start of epoch 23
Training loss at step 0: 0.0979
Seen so far: 128 samples
Training loss at step 200: 0.0015
Seen so far: 25728 samples
Training loss at step 400: 0.0300
Seen so far: 51328 samples
Training acc over epoch: 0.9963
Validation acc: 0.9730
Time taken: 1.04s

Start of epoch 24
Training loss at step 0: 0.0020
Seen so far: 128 samples
Training loss at step 200: 0.0551
Seen so far: 25728 samples
Training loss at step 400: 0.0254
Seen so far: 51328 samples
Training acc over epoch: 0.9958
Validation acc: 0.9762
Time taken: 1.05s

Start of epoch 25
Training loss at step 0: 0.0002
Seen so far: 128 samples
Training loss at step 200: 0.0230
Seen so far: 25728 samples
Training loss at step 400: 0.0410
Seen so far: 51328 samples
Training acc over epoch: 0.9962
Validation acc: 0.9718
Time taken: 1.09s

Start of epoch 26
Training loss at step 0: 0.0001
Seen so far: 128 samples
Training loss at step 200: 0.0016
Seen so far: 25728 samples
Training loss at step 400: 0.0001
Seen so far: 51328 samples
Training acc over epoch: 0.9957
Validation acc: 0.9730
Time taken: 1.03s

Start of epoch 27
Training loss at step 0: 0.0008
Seen so far: 128 samples
Training loss at step 200: 0.0919
Seen so far: 25728 samples
Training loss at step 400: 0.0024
Seen so far: 51328 samples
Training acc over epoch: 0.9956
Validation acc: 0.9726
Time taken: 1.04s
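The patience bookkeeping in the loop above (the `wait` counter and `best` value) can also be factored into a small reusable helper, which keeps the training loop itself short. The sketch below is not part of the original notebook; the `min_delta` parameter is an added, hypothetical knob for ignoring negligible improvements:

```python
class EarlyStopper:
  """Tracks the best validation loss and signals when to stop training."""

  def __init__(self, patience=5, min_delta=0.0):
    self.patience = patience
    self.min_delta = min_delta  # improvements smaller than this are ignored
    self.wait = 0
    self.best = float('inf')

  def should_stop(self, val_loss):
    """Call once per epoch; returns True when training should stop."""
    if val_loss < self.best - self.min_delta:
      self.best = val_loss
      self.wait = 0
    else:
      self.wait += 1
    return self.wait >= self.patience

# Simulated per-epoch validation losses: improvement stalls after epoch 1,
# so with patience=3 the stop signal fires on the fifth epoch.
stopper = EarlyStopper(patience=3)
stops = [stopper.should_stop(l) for l in [0.5, 0.4, 0.41, 0.42, 0.43]]
```

Inside the custom loop, the whole early stopping block then reduces to `if stopper.should_stop(val_loss): break`.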

Next steps