Migrate early stopping


This notebook demonstrates how to set up model training with early stopping, first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training once a chosen condition is met, for example when the validation loss reaches a certain threshold.
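Below is a minimal, framework-free sketch of that threshold rule. The function name and the loss values are hypothetical. In real training, the validation loss would be computed after each epoch rather than read from a list.

```python
def train_with_threshold_stopping(val_losses, threshold):
    """Return the number of epochs run before training stops.

    `val_losses` stands in for the validation loss measured after each epoch.
    Training stops as soon as the loss falls to or below `threshold`;
    if that never happens, every epoch runs.
    """
    for epoch, loss in enumerate(val_losses, start=1):
        if loss <= threshold:
            return epoch  # Stopping condition met: end training early.
    return len(val_losses)  # Condition never met: ran to completion.


# Hypothetical per-epoch validation losses.
losses = [0.90, 0.55, 0.41, 0.33, 0.29, 0.27]
print(train_with_threshold_stopping(losses, threshold=0.35))  # → 4
```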

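In TensorFlow 2, there are three ways to implement early stopping:

- Use the built-in Keras callback tf.keras.callbacks.EarlyStopping and pass it to Model.fit.
- Define a custom callback and pass it to Keras Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).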
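The built-in tf.keras.callbacks.EarlyStopping callback is configured with arguments such as monitor, min_delta and patience, and it stops training once the monitored metric has not improved for patience epochs. The plain-Python sketch below approximates that patience rule for a metric that should decrease, such as val_loss. It is a simplified illustration, not the Keras source. The function name and loss values are hypothetical.

```python
import math


def epochs_until_patience_stop(val_losses, patience, min_delta=0.0):
    """Return the 1-based epoch after which training would stop.

    An epoch counts as an improvement only if its loss is lower than the best
    loss so far by more than `min_delta`. Training stops once `patience`
    consecutive epochs pass without improvement; otherwise all epochs run.
    """
    best = math.inf
    wait = 0  # Consecutive epochs without improvement.
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses)


# Hypothetical per-epoch validation losses.
losses = [0.50, 0.40, 0.38, 0.39, 0.385, 0.41, 0.30]
print(epochs_until_patience_stop(losses, patience=3))  # → 6
```

With patience=3 the sketch stops at epoch 6 and never sees the lower loss at epoch 7. A larger patience costs extra epochs but makes the rule less sensitive to noisy validation curves.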

Setup

import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds

TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

Start by defining functions for loading and preprocessing the MNIST dataset, and a model function to be used with tf.estimator.Estimator:

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label  # Scale pixels to [0, 1].

def _input_fn():
  ds_train = tfds.load(
      name='mnist',
      split='train',
      shuffle_files=True,
      as_supervised=True)  # Yields (image, label) pairs.

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)  # Up to 100 epochs; early stopping can end training sooner.
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
      name='mnist',
      split='test',
      shuffle_files=True,
      as_supervised=True)
  ds_test = ds_test.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)  # No repeat: the test set is read at most once per evaluation.
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)  # One logit per digit class.

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
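Because _input_fn batches 128 examples and repeats the data 100 times, it sets an upper bound on how long training could run if nothing stopped it. The arithmetic below assumes the MNIST train split has 60,000 examples and that batch keeps the final partial batch of each pass (its default, drop_remainder=False). The helper name is hypothetical.

```python
import math


def max_training_steps(num_examples, batch_size, repeats):
    """Return (steps per pass, total steps) for a batched, then repeated dataset.

    With drop_remainder=False the last, smaller batch of each pass is kept,
    so each pass yields ceil(num_examples / batch_size) batches.
    """
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch, steps_per_epoch * repeats


steps_per_epoch, total_steps = max_training_steps(60_000, 128, 100)
print(steps_per_epoch, total_steps)  # → 469 46900
```

In the run logged below, early stopping is requested at global step 8286, about 17.7 passes over the data (8286 / 469), far short of the 46,900-step ceiling.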

In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass a function that takes no arguments to make_early_stopping_hook as its should_stop_fn parameter; the hook calls this function periodically, and training stops once it returns True. You then pass the resulting hook to tf.estimator.TrainSpec through its hooks argument.
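Because should_stop_fn takes no arguments, any state it needs (a start time, a counter, a metric history) must be captured from the enclosing scope. The sketch below simulates, in plain Python, a hook that polls such a function. All names are hypothetical, and the loop is a stand-in for the Estimator's training loop, not its implementation. The real hook polls by wall-clock time (run_every_secs) or by step count (run_every_steps); this sketch uses a step interval, and a fake clock replaces time.time so the result is deterministic.

```python
def make_time_budget_fn(max_seconds, clock):
    """Build a zero-argument should_stop_fn that returns True once the budget is spent."""
    start = clock()  # Captured when the function is built, not when training starts.

    def should_stop_fn():
        return clock() - start > max_seconds

    return should_stop_fn


def simulate_training(should_stop_fn, max_steps, check_every_steps, on_step=None):
    """Run up to max_steps fake steps, polling should_stop_fn every check_every_steps.

    Returns the number of steps completed when training ended.
    """
    for step in range(1, max_steps + 1):
        if on_step is not None:
            on_step()  # Stand-in for the work (and time) of one training step.
        if step % check_every_steps == 0 and should_stop_fn():
            return step
    return max_steps


class FakeClock:
    """Deterministic stand-in for time.time(): advances only when tick() is called."""

    def __init__(self, step_seconds):
        self.now = 0.0
        self.step_seconds = step_seconds

    def __call__(self):
        return self.now

    def tick(self):
        self.now += self.step_seconds


clock = FakeClock(step_seconds=0.25)  # Pretend each step takes 0.25 s.
should_stop_fn = make_time_budget_fn(max_seconds=10, clock=clock)
print(simulate_training(should_stop_fn, max_steps=1000, check_every_steps=5,
                        on_step=clock.tick))  # → 45
```

The 10-second budget is exceeded after step 41, but the stop is only noticed at the next check, step 45. The real hook behaves analogously: training can continue briefly past the limit until the next time should_stop_fn is polled.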

The following example implements early stopping that limits training time to a maximum of 20 seconds, checking the stopping condition about once per second (run_every_secs=1):

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds  # True once the time budget is used up.

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,  # Poll should_stop_fn about once per second.
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpp8x_ipb8
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpp8x_ipb8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpp8x_ipb8/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.2961025, step = 0
INFO:tensorflow:global_step/sec: 335.823
INFO:tensorflow:loss = 1.318093, step = 100 (0.300 sec)
INFO:tensorflow:global_step/sec: 423.864
INFO:tensorflow:loss = 0.8593272, step = 200 (0.236 sec)
INFO:tensorflow:global_step/sec: 436.472
INFO:tensorflow:loss = 0.74179834, step = 300 (0.229 sec)
INFO:tensorflow:global_step/sec: 432.337
INFO:tensorflow:loss = 0.64878374, step = 400 (0.232 sec)
INFO:tensorflow:global_step/sec: 425.537
INFO:tensorflow:loss = 0.49600804, step = 500 (0.234 sec)
INFO:tensorflow:global_step/sec: 505.423
INFO:tensorflow:loss = 0.44587734, step = 600 (0.198 sec)
INFO:tensorflow:global_step/sec: 518.525
INFO:tensorflow:loss = 0.38416976, step = 700 (0.193 sec)
INFO:tensorflow:global_step/sec: 509.36
INFO:tensorflow:loss = 0.5128687, step = 800 (0.197 sec)
INFO:tensorflow:global_step/sec: 531.176
INFO:tensorflow:loss = 0.38511667, step = 900 (0.188 sec)
INFO:tensorflow:global_step/sec: 465.699
INFO:tensorflow:loss = 0.43401116, step = 1000 (0.215 sec)
INFO:tensorflow:global_step/sec: 495.799
INFO:tensorflow:loss = 0.44691753, step = 1100 (0.203 sec)
INFO:tensorflow:global_step/sec: 504.403
INFO:tensorflow:loss = 0.40272003, step = 1200 (0.198 sec)
INFO:tensorflow:global_step/sec: 482.232
INFO:tensorflow:loss = 0.47271937, step = 1300 (0.207 sec)
INFO:tensorflow:global_step/sec: 501.553
INFO:tensorflow:loss = 0.29635084, step = 1400 (0.200 sec)
INFO:tensorflow:global_step/sec: 459.337
INFO:tensorflow:loss = 0.293486, step = 1500 (0.218 sec)
INFO:tensorflow:global_step/sec: 511.454
INFO:tensorflow:loss = 0.40195698, step = 1600 (0.196 sec)
INFO:tensorflow:global_step/sec: 487.452
INFO:tensorflow:loss = 0.38753498, step = 1700 (0.206 sec)
INFO:tensorflow:global_step/sec: 500.287
INFO:tensorflow:loss = 0.3344679, step = 1800 (0.199 sec)
INFO:tensorflow:global_step/sec: 476.868
INFO:tensorflow:loss = 0.49753922, step = 1900 (0.210 sec)
INFO:tensorflow:global_step/sec: 484.689
INFO:tensorflow:loss = 0.21684857, step = 2000 (0.207 sec)
INFO:tensorflow:global_step/sec: 496.691
INFO:tensorflow:loss = 0.28068116, step = 2100 (0.202 sec)
INFO:tensorflow:global_step/sec: 499.682
INFO:tensorflow:loss = 0.3000077, step = 2200 (0.200 sec)
INFO:tensorflow:global_step/sec: 507.391
INFO:tensorflow:loss = 0.34870982, step = 2300 (0.197 sec)
INFO:tensorflow:global_step/sec: 475.414
INFO:tensorflow:loss = 0.24876948, step = 2400 (0.211 sec)
INFO:tensorflow:global_step/sec: 499.39
INFO:tensorflow:loss = 0.21644332, step = 2500 (0.200 sec)
INFO:tensorflow:global_step/sec: 500.659
INFO:tensorflow:loss = 0.152693, step = 2600 (0.200 sec)
INFO:tensorflow:global_step/sec: 500.201
INFO:tensorflow:loss = 0.33327985, step = 2700 (0.200 sec)
INFO:tensorflow:global_step/sec: 517.488
INFO:tensorflow:loss = 0.47266263, step = 2800 (0.193 sec)
INFO:tensorflow:global_step/sec: 466.55
INFO:tensorflow:loss = 0.24876104, step = 2900 (0.215 sec)
INFO:tensorflow:global_step/sec: 499.818
INFO:tensorflow:loss = 0.33199376, step = 3000 (0.200 sec)
INFO:tensorflow:global_step/sec: 477.253
INFO:tensorflow:loss = 0.19820198, step = 3100 (0.210 sec)
INFO:tensorflow:global_step/sec: 512.079
INFO:tensorflow:loss = 0.4163157, step = 3200 (0.195 sec)
INFO:tensorflow:global_step/sec: 461.596
INFO:tensorflow:loss = 0.3364423, step = 3300 (0.216 sec)
INFO:tensorflow:global_step/sec: 488.092
INFO:tensorflow:loss = 0.25606278, step = 3400 (0.206 sec)
INFO:tensorflow:global_step/sec: 490.26
INFO:tensorflow:loss = 0.20572862, step = 3500 (0.204 sec)
INFO:tensorflow:global_step/sec: 487.323
INFO:tensorflow:loss = 0.2212799, step = 3600 (0.206 sec)
INFO:tensorflow:global_step/sec: 492.262
INFO:tensorflow:loss = 0.282781, step = 3700 (0.203 sec)
INFO:tensorflow:global_step/sec: 449.227
INFO:tensorflow:loss = 0.365446, step = 3800 (0.223 sec)
INFO:tensorflow:global_step/sec: 521.924
INFO:tensorflow:loss = 0.22579709, step = 3900 (0.191 sec)
INFO:tensorflow:global_step/sec: 535.844
INFO:tensorflow:loss = 0.30844557, step = 4000 (0.187 sec)
INFO:tensorflow:global_step/sec: 496.551
INFO:tensorflow:loss = 0.21947613, step = 4100 (0.202 sec)
INFO:tensorflow:global_step/sec: 508.086
INFO:tensorflow:loss = 0.26513258, step = 4200 (0.197 sec)
INFO:tensorflow:global_step/sec: 462.149
INFO:tensorflow:loss = 0.29323363, step = 4300 (0.217 sec)
INFO:tensorflow:global_step/sec: 503.226
INFO:tensorflow:loss = 0.31204918, step = 4400 (0.198 sec)
INFO:tensorflow:global_step/sec: 510.874
INFO:tensorflow:loss = 0.26014802, step = 4500 (0.196 sec)
INFO:tensorflow:global_step/sec: 511.332
INFO:tensorflow:loss = 0.33227044, step = 4600 (0.196 sec)
INFO:tensorflow:global_step/sec: 472.525
INFO:tensorflow:loss = 0.13610935, step = 4700 (0.211 sec)
INFO:tensorflow:global_step/sec: 505.262
INFO:tensorflow:loss = 0.28594193, step = 4800 (0.198 sec)
INFO:tensorflow:global_step/sec: 515.453
INFO:tensorflow:loss = 0.38484165, step = 4900 (0.195 sec)
INFO:tensorflow:global_step/sec: 479.259
INFO:tensorflow:loss = 0.27081215, step = 5000 (0.208 sec)
INFO:tensorflow:global_step/sec: 491.361
INFO:tensorflow:loss = 0.33877313, step = 5100 (0.204 sec)
INFO:tensorflow:global_step/sec: 471.256
INFO:tensorflow:loss = 0.2074028, step = 5200 (0.212 sec)
INFO:tensorflow:global_step/sec: 506.672
INFO:tensorflow:loss = 0.24718614, step = 5300 (0.197 sec)
INFO:tensorflow:global_step/sec: 491.333
INFO:tensorflow:loss = 0.16439602, step = 5400 (0.203 sec)
INFO:tensorflow:global_step/sec: 500.638
INFO:tensorflow:loss = 0.22073755, step = 5500 (0.200 sec)
INFO:tensorflow:global_step/sec: 508.206
INFO:tensorflow:loss = 0.18545151, step = 5600 (0.196 sec)
INFO:tensorflow:global_step/sec: 450.634
INFO:tensorflow:loss = 0.17126478, step = 5700 (0.222 sec)
INFO:tensorflow:global_step/sec: 525.95
INFO:tensorflow:loss = 0.2689212, step = 5800 (0.190 sec)
INFO:tensorflow:global_step/sec: 483.04
INFO:tensorflow:loss = 0.21353054, step = 5900 (0.207 sec)
INFO:tensorflow:global_step/sec: 480.064
INFO:tensorflow:loss = 0.25207376, step = 6000 (0.209 sec)
INFO:tensorflow:global_step/sec: 451.147
INFO:tensorflow:loss = 0.18487616, step = 6100 (0.221 sec)
INFO:tensorflow:global_step/sec: 490.621
INFO:tensorflow:loss = 0.267612, step = 6200 (0.204 sec)
INFO:tensorflow:global_step/sec: 499.641
INFO:tensorflow:loss = 0.24492177, step = 6300 (0.200 sec)
INFO:tensorflow:global_step/sec: 498.738
INFO:tensorflow:loss = 0.28542638, step = 6400 (0.200 sec)
INFO:tensorflow:global_step/sec: 523.196
INFO:tensorflow:loss = 0.26353425, step = 6500 (0.191 sec)
INFO:tensorflow:global_step/sec: 455.054
INFO:tensorflow:loss = 0.19190696, step = 6600 (0.220 sec)
INFO:tensorflow:global_step/sec: 522.506
INFO:tensorflow:loss = 0.25657248, step = 6700 (0.191 sec)
INFO:tensorflow:global_step/sec: 506.766
INFO:tensorflow:loss = 0.39108855, step = 6800 (0.197 sec)
INFO:tensorflow:global_step/sec: 481.896
INFO:tensorflow:loss = 0.15236983, step = 6900 (0.208 sec)
INFO:tensorflow:global_step/sec: 506.833
INFO:tensorflow:loss = 0.3356074, step = 7000 (0.197 sec)
INFO:tensorflow:global_step/sec: 467.501
INFO:tensorflow:loss = 0.1956916, step = 7100 (0.214 sec)
INFO:tensorflow:global_step/sec: 494.573
INFO:tensorflow:loss = 0.24700633, step = 7200 (0.203 sec)
INFO:tensorflow:global_step/sec: 522.627
INFO:tensorflow:loss = 0.17747968, step = 7300 (0.191 sec)
INFO:tensorflow:global_step/sec: 486.106
INFO:tensorflow:loss = 0.28143722, step = 7400 (0.206 sec)
INFO:tensorflow:global_step/sec: 479.817
INFO:tensorflow:loss = 0.14399035, step = 7500 (0.209 sec)
INFO:tensorflow:global_step/sec: 487.355
INFO:tensorflow:loss = 0.2522769, step = 7600 (0.205 sec)
INFO:tensorflow:global_step/sec: 526.322
INFO:tensorflow:loss = 0.25871283, step = 7700 (0.190 sec)
INFO:tensorflow:global_step/sec: 501.783
INFO:tensorflow:loss = 0.16747415, step = 7800 (0.199 sec)
INFO:tensorflow:global_step/sec: 485.409
INFO:tensorflow:loss = 0.15161222, step = 7900 (0.205 sec)
INFO:tensorflow:global_step/sec: 462.017
INFO:tensorflow:loss = 0.18137535, step = 8000 (0.216 sec)
INFO:tensorflow:global_step/sec: 488.439
INFO:tensorflow:loss = 0.19478491, step = 8100 (0.205 sec)
INFO:tensorflow:global_step/sec: 485.238
INFO:tensorflow:loss = 0.27511528, step = 8200 (0.206 sec)
INFO:tensorflow:Requesting early stopping at global step 8286
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8287...
INFO:tensorflow:Saving checkpoints for 8287 into /tmpfs/tmp/tmpp8x_ipb8/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8287...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2023-10-04T01:37:52
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpp8x_ipb8/model.ckpt-8287
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 0.51915s
INFO:tensorflow:Inference Time : 0.51915s
INFO:tensorflow:Finished evaluation at 2023-10-04-01:37:53
INFO:tensorflow:Finished evaluation at 2023-10-04-01:37:53
INFO:tensorflow:Saving dict for global step 8287: global_step = 8287, loss = 0.21179305
INFO:tensorflow:Saving dict for global step 8287: global_step = 8287, loss = 0.21179305
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8287: /tmpfs/tmp/tmpp8x_ipb8/model.ckpt-8287
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8287: /tmpfs/tmp/tmpp8x_ipb8/model.ckpt-8287
INFO:tensorflow:Loss for final step: 0.28788015.
INFO:tensorflow:Loss for final step: 0.28788015.
({'loss': 0.21179305, 'global_step': 8287}, [])

TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
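The step counts in the training logs follow from the batch size. Dataset.batch keeps the final partial batch by default (drop_remainder=False). One epoch over the 60,000 MNIST training images therefore takes ceil(60000 / 128) = 469 steps, and the last step holds only 96 images. The 10,000 test images take 79 steps. A minimal sketch of that arithmetic:

```python
import math

train_examples = 60_000  # size of the MNIST 'train' split
test_examples = 10_000   # size of the MNIST 'test' split
batch_size = 128         # matches ds_train.batch(128) / ds_test.batch(128)

# With drop_remainder=False, a trailing partial batch still counts as a step.
train_steps = math.ceil(train_examples / batch_size)
last_batch = train_examples - (train_steps - 1) * batch_size
test_steps = math.ceil(test_examples / batch_size)

print(train_steps, last_batch, test_steps)  # 469 96 79
```

The 469 matches the 469/469 counter in the Model.fit progress bars below.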

In TensorFlow 2, when you train with the built-in Keras Model.fit, you can configure early stopping by passing the built-in tf.keras.callbacks.EarlyStopping callback to the callbacks parameter of Model.fit.

The EarlyStopping callback monitors a user-specified metric and ends training when that metric stops improving. (Check the Training and evaluation with the built-in methods guide or the EarlyStopping API docs for more information.)

Below is an example of an early stopping callback that monitors the training loss and stops training once the loss has not improved for 3 consecutive epochs (patience=3):

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Training stops well before the 100-epoch limit (after 16 epochs in the run below).
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 3s 4ms/step - loss: 0.2326 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.1214 - val_sparse_categorical_accuracy: 0.9630
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.1004 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.1009 - val_sparse_categorical_accuracy: 0.9684
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0689 - sparse_categorical_accuracy: 0.9792 - val_loss: 0.1061 - val_sparse_categorical_accuracy: 0.9674
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0528 - sparse_categorical_accuracy: 0.9835 - val_loss: 0.1244 - val_sparse_categorical_accuracy: 0.9640
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0424 - sparse_categorical_accuracy: 0.9866 - val_loss: 0.1001 - val_sparse_categorical_accuracy: 0.9730
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0355 - sparse_categorical_accuracy: 0.9883 - val_loss: 0.1034 - val_sparse_categorical_accuracy: 0.9728
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0315 - sparse_categorical_accuracy: 0.9894 - val_loss: 0.1029 - val_sparse_categorical_accuracy: 0.9748
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0319 - sparse_categorical_accuracy: 0.9888 - val_loss: 0.1267 - val_sparse_categorical_accuracy: 0.9707
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0280 - sparse_categorical_accuracy: 0.9905 - val_loss: 0.1128 - val_sparse_categorical_accuracy: 0.9733
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0281 - sparse_categorical_accuracy: 0.9907 - val_loss: 0.1246 - val_sparse_categorical_accuracy: 0.9718
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0270 - sparse_categorical_accuracy: 0.9905 - val_loss: 0.1184 - val_sparse_categorical_accuracy: 0.9786
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0217 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1231 - val_sparse_categorical_accuracy: 0.9756
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0191 - sparse_categorical_accuracy: 0.9935 - val_loss: 0.1466 - val_sparse_categorical_accuracy: 0.9716
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0220 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1475 - val_sparse_categorical_accuracy: 0.9754
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9934 - val_loss: 0.1472 - val_sparse_categorical_accuracy: 0.9746
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0195 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1454 - val_sparse_categorical_accuracy: 0.9755
16
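In this run, the training loss last improved at Epoch 13 (0.0191) and then failed to improve for three epochs in a row, so training stopped after 16 epochs. The plain-Python function below is a simplified model of the patience rule. It assumes mode 'min' and ignores baseline and restore_best_weights. It is not the Keras source, only a way to reason about when the callback fires. Applied to the losses printed above, it reproduces the count of 16:

```python
def epochs_before_stop(losses, patience, min_delta=0.0):
    """Simplified patience rule (mode='min').

    Returns the number of epochs that run before training is stopped,
    or len(losses) if the rule never fires.
    """
    best = float("inf")
    wait = 0
    for epoch, current in enumerate(losses):
        if current < best - min_delta:
            best = current  # new best value: reset the patience counter
            wait = 0
        else:
            wait += 1       # no improvement this epoch
        if wait >= patience:
            return epoch + 1  # epochs are 0-indexed
    return len(losses)


# Training losses from the 16 epochs logged above, rounded as printed.
logged = [0.2326, 0.1004, 0.0689, 0.0528, 0.0424, 0.0355, 0.0315, 0.0319,
          0.0280, 0.0281, 0.0270, 0.0217, 0.0191, 0.0220, 0.0200, 0.0195]
print(epochs_before_stop(logged, patience=3))  # 16
```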

TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback and pass it to the callbacks parameter of Model.fit.

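The example below enforces a training-time budget. The callback records the start time in on_train_begin. After each training batch, it sets self.model.stop_training to True once the budget is exceeded, which ends training: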

class LimitTrainingTime(tf.keras.callbacks.Callback):
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs=None):
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs=None):
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True

# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0180 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1440 - val_sparse_categorical_accuracy: 0.9785
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1728 - val_sparse_categorical_accuracy: 0.9755
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0213 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1567 - val_sparse_categorical_accuracy: 0.9787
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0198 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1826 - val_sparse_categorical_accuracy: 0.9740
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0146 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1658 - val_sparse_categorical_accuracy: 0.9756
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0126 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2150 - val_sparse_categorical_accuracy: 0.9726
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0174 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1890 - val_sparse_categorical_accuracy: 0.9769
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.1852 - val_sparse_categorical_accuracy: 0.9772
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.2051 - val_sparse_categorical_accuracy: 0.9761
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2260 - val_sparse_categorical_accuracy: 0.9727
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0167 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.2008 - val_sparse_categorical_accuracy: 0.9757
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2283 - val_sparse_categorical_accuracy: 0.9755
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0106 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.2270 - val_sparse_categorical_accuracy: 0.9739
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2169 - val_sparse_categorical_accuracy: 0.9775
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0121 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2282 - val_sparse_categorical_accuracy: 0.9773
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2723 - val_sparse_categorical_accuracy: 0.9738
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0118 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2223 - val_sparse_categorical_accuracy: 0.9784
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0168 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2489 - val_sparse_categorical_accuracy: 0.9770
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2607 - val_sparse_categorical_accuracy: 0.9753
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0112 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2267 - val_sparse_categorical_accuracy: 0.9776
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0105 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2758 - val_sparse_categorical_accuracy: 0.9733
21
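In recent tf.keras releases, Model.fit checks stop_training after every batch (an assumption worth confirming for the version you use). A time budget can therefore end training partway through an epoch. The framework-free simulation below mimics that protocol with a fake clock so the behavior can be checked deterministically. FakeClock, LimitTrainingTimeSketch, and simulate_fit are hypothetical stand-ins, not Keras APIs.

```python
class FakeClock:
    """Hypothetical stand-in for time.time() that advances a fixed amount per step."""

    def __init__(self, seconds_per_step):
        self.now = 0.0
        self.seconds_per_step = seconds_per_step

    def __call__(self):
        return self.now

    def tick(self):
        self.now += self.seconds_per_step


class LimitTrainingTimeSketch:
    """Same logic as LimitTrainingTime, with the clock injected for testing."""

    def __init__(self, max_time_s, clock):
        self.max_time_s = max_time_s
        self.clock = clock
        self.start_time = None
        self.stop_training = False  # plays the role of self.model.stop_training

    def on_train_begin(self):
        self.start_time = self.clock()

    def on_train_batch_end(self, batch):
        if self.clock() - self.start_time > self.max_time_s:
            self.stop_training = True


def simulate_fit(callback, clock, epochs, steps_per_epoch):
    """Hypothetical fit loop; returns (epochs started, batches run)."""
    callback.on_train_begin()
    batches_run = 0
    for epoch in range(epochs):
        for batch in range(steps_per_epoch):
            clock.tick()  # pretend one training step happened
            batches_run += 1
            callback.on_train_batch_end(batch)
            if callback.stop_training:  # flag checked after every batch
                return epoch + 1, batches_run
    return epochs, batches_run


clock = FakeClock(seconds_per_step=1.0)
cb = LimitTrainingTimeSketch(max_time_s=30, clock=clock)
print(simulate_fit(cb, clock, epochs=100, steps_per_epoch=10))  # (4, 31)
```

The budget is exceeded on the 31st batch, which falls in the fourth epoch. The run ends there instead of at an epoch boundary.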

TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, you can implement early stopping in a custom training loop if you're not training and evaluating with the built-in Keras methods.

Start by using Keras APIs to define another simple model, an optimizer, a loss function, and metrics:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)

Define the training and evaluation step functions with tf.GradientTape, and decorate them with @tf.function so they are compiled into graphs for a speedup:

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)
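Both step functions feed raw logits into the loss function and the loss metrics, because the final Dense(10) layer has no activation. That is why loss_fn is built with from_logits=True. The cross-entropy metrics need the same setting, and tf.keras.metrics.SparseCategoricalCrossentropy accepts a from_logits argument for this purpose. The framework-free sketch below shows what computing cross-entropy "from logits" means for one example. It applies a numerically stable log-softmax and then picks out the true class. It illustrates the math only and is not the Keras implementation.

```python
import math


def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one example whose model output is raw logits."""
    # Log-sum-exp with the max subtracted so large logits do not overflow.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    # Equals -log(softmax(logits)[label]).
    return log_sum_exp - logits[label]


print(round(sparse_ce_from_logits([2.0, 1.0, 0.1], 0), 4))  # 0.417
```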

Next, write a custom training loop, where you can implement your early stopping rule manually.

The example below shows how to stop training when the validation loss doesn't improve over a certain number of epochs:

epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
      loss_value = train_step(x_batch_train, y_batch_train)
      if step % 200 == 0:
        print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
        print("Seen so far: %s samples" % ((step + 1) * 128))        
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    for x_batch_val, y_batch_val in ds_test:
      test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over a certain number of epochs.
    wait += 1
    if val_loss < best:
      best = val_loss
      wait = 0
    if wait >= patience:
      break
Start of epoch 0
Training loss at step 0: 2.4644
Seen so far: 128 samples
Training loss at step 200: 0.2622
Seen so far: 25728 samples
Training loss at step 400: 0.2129
Seen so far: 51328 samples
Training acc over epoch: 0.9297
Validation acc: 0.9610
Time taken: 2.07s

Start of epoch 1
Training loss at step 0: 0.0756
Seen so far: 128 samples
Training loss at step 200: 0.1215
Seen so far: 25728 samples
Training loss at step 400: 0.1359
Seen so far: 51328 samples
Training acc over epoch: 0.9685
Validation acc: 0.9664
Time taken: 1.36s

Start of epoch 2
Training loss at step 0: 0.0390
Seen so far: 128 samples
Training loss at step 200: 0.0995
Seen so far: 25728 samples
Training loss at step 400: 0.1270
Seen so far: 51328 samples
Training acc over epoch: 0.9778
Validation acc: 0.9690
Time taken: 1.37s

Start of epoch 3
Training loss at step 0: 0.0282
Seen so far: 128 samples
Training loss at step 200: 0.0673
Seen so far: 25728 samples
Training loss at step 400: 0.0700
Seen so far: 51328 samples
Training acc over epoch: 0.9826
Validation acc: 0.9695
Time taken: 1.35s

Start of epoch 4
Training loss at step 0: 0.0249
Seen so far: 128 samples
Training loss at step 200: 0.0403
Seen so far: 25728 samples
Training loss at step 400: 0.0324
Seen so far: 51328 samples
Training acc over epoch: 0.9862
Validation acc: 0.9693
Time taken: 1.37s

Start of epoch 5
Training loss at step 0: 0.0279
Seen so far: 128 samples
Training loss at step 200: 0.0529
Seen so far: 25728 samples
Training loss at step 400: 0.0592
Seen so far: 51328 samples
Training acc over epoch: 0.9875
Validation acc: 0.9660
Time taken: 1.32s

Start of epoch 6
Training loss at step 0: 0.0200
Seen so far: 128 samples
Training loss at step 200: 0.0344
Seen so far: 25728 samples
Training loss at step 400: 0.0142
Seen so far: 51328 samples
Training acc over epoch: 0.9884
Validation acc: 0.9720
Time taken: 1.36s

Start of epoch 7
Training loss at step 0: 0.0325
Seen so far: 128 samples
Training loss at step 200: 0.0499
Seen so far: 25728 samples
Training loss at step 400: 0.0167
Seen so far: 51328 samples
Training acc over epoch: 0.9898
Validation acc: 0.9711
Time taken: 1.39s

Start of epoch 8
Training loss at step 0: 0.0155
Seen so far: 128 samples
Training loss at step 200: 0.0219
Seen so far: 25728 samples
Training loss at step 400: 0.0213
Seen so far: 51328 samples
Training acc over epoch: 0.9914
Validation acc: 0.9715
Time taken: 1.34s

Start of epoch 9
Training loss at step 0: 0.0068
Seen so far: 128 samples
Training loss at step 200: 0.0288
Seen so far: 25728 samples
Training loss at step 400: 0.0320
Seen so far: 51328 samples
Training acc over epoch: 0.9921
Validation acc: 0.9743
Time taken: 1.34s

Start of epoch 10
Training loss at step 0: 0.0069
Seen so far: 128 samples
Training loss at step 200: 0.0419
Seen so far: 25728 samples
Training loss at step 400: 0.0458
Seen so far: 51328 samples
Training acc over epoch: 0.9919
Validation acc: 0.9720
Time taken: 1.34s

Start of epoch 11
Training loss at step 0: 0.0015
Seen so far: 128 samples
Training loss at step 200: 0.0113
Seen so far: 25728 samples
Training loss at step 400: 0.0285
Seen so far: 51328 samples
Training acc over epoch: 0.9930
Validation acc: 0.9740
Time taken: 1.33s

Start of epoch 12
Training loss at step 0: 0.0010
Seen so far: 128 samples
Training loss at step 200: 0.0450
Seen so far: 25728 samples
Training loss at step 400: 0.0561
Seen so far: 51328 samples
Training acc over epoch: 0.9926
Validation acc: 0.9744
Time taken: 1.34s

Start of epoch 13
Training loss at step 0: 0.0038
Seen so far: 128 samples
Training loss at step 200: 0.0309
Seen so far: 25728 samples
Training loss at step 400: 0.0124
Seen so far: 51328 samples
Training acc over epoch: 0.9937
Validation acc: 0.9740
Time taken: 1.36s

Start of epoch 14
Training loss at step 0: 0.0011
Seen so far: 128 samples
Training loss at step 200: 0.0068
Seen so far: 25728 samples
Training loss at step 400: 0.0110
Seen so far: 51328 samples
Training acc over epoch: 0.9930
Validation acc: 0.9736
Time taken: 1.35s

Start of epoch 15
Training loss at step 0: 0.0043
Seen so far: 128 samples
Training loss at step 200: 0.0085
Seen so far: 25728 samples
Training loss at step 400: 0.0042
Seen so far: 51328 samples
Training acc over epoch: 0.9942
Validation acc: 0.9752
Time taken: 1.34s

Start of epoch 16
Training loss at step 0: 0.0208
Seen so far: 128 samples
Training loss at step 200: 0.0042
Seen so far: 25728 samples
Training loss at step 400: 0.1063
Seen so far: 51328 samples
Training acc over epoch: 0.9947
Validation acc: 0.9740
Time taken: 1.34s

Start of epoch 17
Training loss at step 0: 0.0067
Seen so far: 128 samples
Training loss at step 200: 0.0277
Seen so far: 25728 samples
Training loss at step 400: 0.0787
Seen so far: 51328 samples
Training acc over epoch: 0.9951
Validation acc: 0.9729
Time taken: 1.33s

Start of epoch 18
Training loss at step 0: 0.0017
Seen so far: 128 samples
Training loss at step 200: 0.0131
Seen so far: 25728 samples
Training loss at step 400: 0.0431
Seen so far: 51328 samples
Training acc over epoch: 0.9943
Validation acc: 0.9739
Time taken: 1.40s

Start of epoch 19
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0220
Seen so far: 25728 samples
Training loss at step 400: 0.0662
Seen so far: 51328 samples
Training acc over epoch: 0.9952
Validation acc: 0.9738
Time taken: 1.34s

Start of epoch 20
Training loss at step 0: 0.0003
Seen so far: 128 samples
Training loss at step 200: 0.0306
Seen so far: 25728 samples
Training loss at step 400: 0.0083
Seen so far: 51328 samples
Training acc over epoch: 0.9955
Validation acc: 0.9753
Time taken: 1.37s

Start of epoch 21
Training loss at step 0: 0.0016
Seen so far: 128 samples
Training loss at step 200: 0.0003
Seen so far: 25728 samples
Training loss at step 400: 0.0069
Seen so far: 51328 samples
Training acc over epoch: 0.9946
Validation acc: 0.9729
Time taken: 1.36s

Start of epoch 22
Training loss at step 0: 0.0626
Seen so far: 128 samples
Training loss at step 200: 0.0013
Seen so far: 25728 samples
Training loss at step 400: 0.0278
Seen so far: 51328 samples
Training acc over epoch: 0.9946
Validation acc: 0.9740
Time taken: 1.34s

Start of epoch 23
Training loss at step 0: 0.0318
Seen so far: 128 samples
Training loss at step 200: 0.0514
Seen so far: 25728 samples
Training loss at step 400: 0.0001
Seen so far: 51328 samples
Training acc over epoch: 0.9952
Validation acc: 0.9758
Time taken: 1.36s

Start of epoch 24
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0043
Seen so far: 25728 samples
Training loss at step 400: 0.0339
Seen so far: 51328 samples
Training acc over epoch: 0.9956
Validation acc: 0.9752
Time taken: 1.37s

Start of epoch 25
Training loss at step 0: 0.0000
Seen so far: 128 samples
Training loss at step 200: 0.0057
Seen so far: 25728 samples
Training loss at step 400: 0.1485
Seen so far: 51328 samples
Training acc over epoch: 0.9961
Validation acc: 0.9733
Time taken: 1.35s

Start of epoch 26
Training loss at step 0: 0.0005
Seen so far: 128 samples
Training loss at step 200: 0.0992
Seen so far: 25728 samples
Training loss at step 400: 0.0033
Seen so far: 51328 samples
Training acc over epoch: 0.9972
Validation acc: 0.9776
Time taken: 1.32s

Start of epoch 27
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0402
Seen so far: 25728 samples
Training loss at step 400: 0.0002
Seen so far: 51328 samples
Training acc over epoch: 0.9966
Validation acc: 0.9784
Time taken: 1.34s

Start of epoch 28
Training loss at step 0: 0.0005
Seen so far: 128 samples
Training loss at step 200: 0.0307
Seen so far: 25728 samples
Training loss at step 400: 0.1466
Seen so far: 51328 samples
Training acc over epoch: 0.9948
Validation acc: 0.9742
Time taken: 1.37s

Start of epoch 29
Training loss at step 0: 0.0092
Seen so far: 128 samples
Training loss at step 200: 0.0039
Seen so far: 25728 samples
Training loss at step 400: 0.0358
Seen so far: 51328 samples
Training acc over epoch: 0.9959
Validation acc: 0.9791
Time taken: 1.33s
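The patience bookkeeping in the loop above can be factored into a small helper that also remembers the model state at the best epoch. This is similar in spirit to EarlyStopping(restore_best_weights=True) in Model.fit. EarlyStopper below is a hypothetical, framework-free sketch, not a TensorFlow API. It keeps the loop's semantics (stop once patience epochs pass without a new best) and adds an optional min_delta.

```python
import copy


class EarlyStopper:
    """Hypothetical helper for a custom training loop.

    Tracks the lowest monitored value, counts epochs without improvement,
    and keeps a snapshot of the state passed in at the best epoch.
    """

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_epoch = None
        self.best_state = None
        self.wait = 0

    def update(self, epoch, value, state=None):
        """Record this epoch's monitored value; return True if training should stop."""
        if value < self.best - self.min_delta:
            self.best = value
            self.best_epoch = epoch
            # Deep-copy so later training steps cannot mutate the snapshot.
            self.best_state = copy.deepcopy(state)
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience
```

In the custom loop, the wait/best bookkeeping could become `if stopper.update(epoch, float(val_loss), model.get_weights()): break`, followed by `model.set_weights(stopper.best_state)` after the loop to roll back to the best epoch. Keras' get_weights should already return NumPy copies. The deepcopy mainly protects any other mutable state you pass in.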

Next steps