This notebook demonstrates how you can set up model training with early stopping: first in TensorFlow 1 with `tf.estimator.Estimator` and an early stopping hook, and then in TensorFlow 2 with the Keras APIs or a custom training loop. Early stopping is a regularization technique that halts training when a monitored quantity, such as the validation loss, stops improving or reaches a chosen threshold.
In TensorFlow 2, there are three ways to implement early stopping:

- Use a built-in Keras callback—`tf.keras.callbacks.EarlyStopping`—and pass it to `Model.fit`.
- Define a custom callback and pass it to Keras `Model.fit`.
- Write a custom early stopping rule in a custom training loop (with `tf.GradientTape`).
Setup
import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
2023-01-21 02:20:59.919521: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-01-21 02:20:59.919649: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2023-01-21 02:20:59.919660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
TensorFlow 1: Early stopping with an early stopping hook and tf.estimator
Start by defining functions for MNIST dataset loading and preprocessing, and a model function to be used with `tf.estimator.Estimator`:
def normalize_img(image, label):
return tf.cast(image, tf.float32) / 255., label
def _input_fn():
ds_train = tfds.load(
name='mnist',
split='train',
shuffle_files=True,
as_supervised=True)
ds_train = ds_train.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)
ds_train = ds_train.repeat(100)
return ds_train
def _eval_input_fn():
ds_test = tfds.load(
name='mnist',
split='test',
shuffle_files=True,
as_supervised=True)
ds_test = ds_test.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
return ds_test
def _model_fn(features, labels, mode):
flatten = tf1.layers.Flatten()(features)
features = tf1.layers.Dense(128, 'relu')(flatten)
logits = tf1.layers.Dense(10)(features)
loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
optimizer = tf1.train.AdagradOptimizer(0.005)
train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
In TensorFlow 1, early stopping works by setting up an early stopping hook with `tf.estimator.experimental.make_early_stopping_hook`. You pass `make_early_stopping_hook` a `should_stop_fn` parameter, which is a function that takes no arguments. Training stops once `should_stop_fn` returns `True`.
The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
start_time = time.time()
max_train_seconds = 20
def should_stop_fn():
return time.time() - start_time > max_train_seconds
early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
estimator=estimator,
should_stop_fn=should_stop_fn,
run_every_secs=1,
run_every_steps=None)
train_spec = tf1.estimator.TrainSpec(
input_fn=_input_fn,
hooks=[early_stopping_hook])
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)
tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpfmt3ke80
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpfmt3ke80/model.ckpt.
INFO:tensorflow:loss = 2.3043585, step = 0
INFO:tensorflow:loss = 1.3497735, step = 100 (0.693 sec)
INFO:tensorflow:loss = 0.83353037, step = 200 (0.646 sec)
...
INFO:tensorflow:loss = 0.18371066, step = 5600 (0.179 sec)
INFO:tensorflow:loss = 0.17668869, step = 5700 (0.199 sec)
INFO:tensorflow:Requesting early stopping at global step 5768
INFO:tensorflow:Saving checkpoints for 5769 into /tmpfs/tmp/tmpfmt3ke80/model.ckpt.
INFO:tensorflow:Starting evaluation at 2023-01-21T02:21:22
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpfmt3ke80/model.ckpt-5769
INFO:tensorflow:Inference Time : 1.06422s
INFO:tensorflow:Finished evaluation at 2023-01-21-02:21:23
INFO:tensorflow:Saving dict for global step 5769: global_step = 5769, loss = 0.24380085
INFO:tensorflow:Loss for final step: 0.22827312.
({'loss': 0.24380085, 'global_step': 5769}, [])
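The time-based `should_stop_fn` above is just one rule. If you want to stop on a metric instead, TF1 also ships convenience wrappers built on top of `make_early_stopping_hook`. The following is a minimal sketch, assuming the `estimator` defined above; the metric name and the step budget are illustrative values, not part of the original notebook:

# A hedged sketch: stop training when the 'loss' metric recorded during
# evaluation has not decreased for 1000 global steps. The values of
# metric_name and max_steps_without_decrease here are only illustrative.
no_decrease_hook = tf1.estimator.experimental.stop_if_no_decrease_hook(
    estimator=estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    run_every_secs=60)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[no_decrease_hook])

As with the hand-rolled hook, you would then pass this `train_spec` to `tf1.estimator.train_and_evaluate`.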
TensorFlow 2: Early stopping with a built-in callback and Model.fit
Prepare the MNIST dataset and a simple Keras model:
(ds_train, ds_test), ds_info = tfds.load(
'mnist',
split=['train', 'test'],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
ds_train = ds_train.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)
ds_test = ds_test.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(
optimizer=tf.keras.optimizers.Adam(0.005),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
In TensorFlow 2, when you use the built-in Keras `Model.fit` (or `Model.evaluate`), you can configure early stopping by passing a built-in callback—`tf.keras.callbacks.EarlyStopping`—to the `callbacks` parameter of `Model.fit`.
The `EarlyStopping` callback monitors a user-specified metric and ends training when it stops improving. (Check the Training and evaluation with the built-in methods guide or the API docs for more information.)
Below is an example of an early stopping callback that monitors the loss and stops training after three epochs with no improvement (set with the `patience` argument):
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# Early stopping ends training well short of the 100 requested epochs.
history = model.fit(
ds_train,
epochs=100,
validation_data=ds_test,
callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 3s 5ms/step - loss: 0.2276 - sparse_categorical_accuracy: 0.9324 - val_loss: 0.1130 - val_sparse_categorical_accuracy: 0.9663
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0958 - sparse_categorical_accuracy: 0.9711 - val_loss: 0.1038 - val_sparse_categorical_accuracy: 0.9689
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0658 - sparse_categorical_accuracy: 0.9800 - val_loss: 0.1069 - val_sparse_categorical_accuracy: 0.9677
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0486 - sparse_categorical_accuracy: 0.9847 - val_loss: 0.0962 - val_sparse_categorical_accuracy: 0.9730
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0393 - sparse_categorical_accuracy: 0.9873 - val_loss: 0.1036 - val_sparse_categorical_accuracy: 0.9726
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0378 - sparse_categorical_accuracy: 0.9872 - val_loss: 0.1097 - val_sparse_categorical_accuracy: 0.9736
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0333 - sparse_categorical_accuracy: 0.9891 - val_loss: 0.1141 - val_sparse_categorical_accuracy: 0.9726
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0300 - sparse_categorical_accuracy: 0.9904 - val_loss: 0.1099 - val_sparse_categorical_accuracy: 0.9756
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0216 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1128 - val_sparse_categorical_accuracy: 0.9754
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0234 - sparse_categorical_accuracy: 0.9926 - val_loss: 0.1548 - val_sparse_categorical_accuracy: 0.9711
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0236 - sparse_categorical_accuracy: 0.9921 - val_loss: 0.1405 - val_sparse_categorical_accuracy: 0.9720
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0235 - sparse_categorical_accuracy: 0.9921 - val_loss: 0.1290 - val_sparse_categorical_accuracy: 0.9752
12
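`EarlyStopping` takes further arguments that refine the rule. The variant below is a sketch with illustrative values, not part of the original run: it monitors the validation loss instead of the training loss, ignores improvements smaller than `min_delta`, and rolls the model back to its best weights when training stops:

# A sketch with illustrative argument values: watch the held-out loss,
# require a minimum improvement per epoch, and restore the best epoch's
# weights when training stops.
callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-3,
    patience=3,
    restore_best_weights=True)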
TensorFlow 2: Early stopping with a custom callback and Model.fit
You can also implement a custom early stopping callback and pass it to the `callbacks` parameter of `Model.fit` (or `Model.evaluate`).
In this example, the training process is stopped once `self.model.stop_training` is set to `True`:
class LimitTrainingTime(tf.keras.callbacks.Callback):
def __init__(self, max_time_s):
super().__init__()
self.max_time_s = max_time_s
self.start_time = None
def on_train_begin(self, logs):
self.start_time = time.time()
def on_train_batch_end(self, batch, logs):
now = time.time()
if now - self.start_time > self.max_time_s:
self.model.stop_training = True
# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
ds_train,
epochs=100,
validation_data=ds_test,
callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0216 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1371 - val_sparse_categorical_accuracy: 0.9749
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0207 - sparse_categorical_accuracy: 0.9934 - val_loss: 0.1495 - val_sparse_categorical_accuracy: 0.9738
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0213 - sparse_categorical_accuracy: 0.9929 - val_loss: 0.1516 - val_sparse_categorical_accuracy: 0.9758
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9955 - val_loss: 0.1699 - val_sparse_categorical_accuracy: 0.9743
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0177 - sparse_categorical_accuracy: 0.9942 - val_loss: 0.1528 - val_sparse_categorical_accuracy: 0.9767
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0183 - sparse_categorical_accuracy: 0.9944 - val_loss: 0.1708 - val_sparse_categorical_accuracy: 0.9744
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0164 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.1825 - val_sparse_categorical_accuracy: 0.9760
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0130 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.1519 - val_sparse_categorical_accuracy: 0.9799
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9949 - val_loss: 0.1943 - val_sparse_categorical_accuracy: 0.9755
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0184 - sparse_categorical_accuracy: 0.9947 - val_loss: 0.2014 - val_sparse_categorical_accuracy: 0.9763
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0158 - sparse_categorical_accuracy: 0.9953 - val_loss: 0.1790 - val_sparse_categorical_accuracy: 0.9783
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0157 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1899 - val_sparse_categorical_accuracy: 0.9766
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0130 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2224 - val_sparse_categorical_accuracy: 0.9740
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0173 - sparse_categorical_accuracy: 0.9955 - val_loss: 0.2038 - val_sparse_categorical_accuracy: 0.9754
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0166 - sparse_categorical_accuracy: 0.9955 - val_loss: 0.2048 - val_sparse_categorical_accuracy: 0.9761
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0086 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.2055 - val_sparse_categorical_accuracy: 0.9756
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2759 - val_sparse_categorical_accuracy: 0.9735
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2361 - val_sparse_categorical_accuracy: 0.9763
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0134 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2329 - val_sparse_categorical_accuracy: 0.9759
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.2216 - val_sparse_categorical_accuracy: 0.9780
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2341 - val_sparse_categorical_accuracy: 0.9775
Epoch 22/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0111 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2499 - val_sparse_categorical_accuracy: 0.9785
Epoch 23/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0124 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2589 - val_sparse_categorical_accuracy: 0.9761
Epoch 24/100
469/469 [==============================] - 1s 2ms/step - loss: 0.0092 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2432 - val_sparse_categorical_accuracy: 0.9761
24
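Keras passes the epoch's metric values to callback hooks through the `logs` dictionary, so the same `stop_training` mechanism supports metric-based rules. The class below is a minimal sketch, not part of the original notebook; the target value is hypothetical, and the metric key assumes the model was compiled with `SparseCategoricalAccuracy` as above:

# A minimal sketch: stop as soon as validation accuracy reaches a
# (hypothetical) target, read from the metrics Keras passes in `logs`.
class StopAtValAccuracy(tf.keras.callbacks.Callback):
  def __init__(self, target=0.98):
    super().__init__()
    self.target = target

  def on_epoch_end(self, epoch, logs=None):
    logs = logs or {}
    if logs.get('val_sparse_categorical_accuracy', 0.0) >= self.target:
      self.model.stop_training = True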
TensorFlow 2: Early stopping with a custom training loop
In TensorFlow 2, you can implement early stopping in a custom training loop if you're not training and evaluating with the built-in Keras methods.
Start by using Keras APIs to define another simple model, an optimizer, a loss function, and metrics:
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy()
Define the parameter update functions with `tf.GradientTape` and the `@tf.function` decorator for a speedup:
@tf.function
def train_step(x, y):
with tf.GradientTape() as tape:
logits = model(x, training=True)
loss_value = loss_fn(y, logits)
grads = tape.gradient(loss_value, model.trainable_weights)
optimizer.apply_gradients(zip(grads, model.trainable_weights))
train_acc_metric.update_state(y, logits)
train_loss_metric.update_state(y, logits)
return loss_value
@tf.function
def test_step(x, y):
logits = model(x, training=False)
val_acc_metric.update_state(y, logits)
val_loss_metric.update_state(y, logits)
Next, write a custom training loop, where you can implement your early stopping rule manually.
The example below shows how to stop training when the validation loss doesn't improve over a certain number of epochs:
epochs = 100
patience = 5
wait = 0
best = float('inf')
for epoch in range(epochs):
print("\nStart of epoch %d" % (epoch,))
start_time = time.time()
for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
loss_value = train_step(x_batch_train, y_batch_train)
if step % 200 == 0:
print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
print("Seen so far: %s samples" % ((step + 1) * 128))
train_acc = train_acc_metric.result()
train_loss = train_loss_metric.result()
train_acc_metric.reset_states()
train_loss_metric.reset_states()
print("Training acc over epoch: %.4f" % (train_acc.numpy()))
for x_batch_val, y_batch_val in ds_test:
test_step(x_batch_val, y_batch_val)
val_acc = val_acc_metric.result()
val_loss = val_loss_metric.result()
val_acc_metric.reset_states()
val_loss_metric.reset_states()
print("Validation acc: %.4f" % (float(val_acc),))
print("Time taken: %.2fs" % (time.time() - start_time))
# The early stopping strategy: stop the training if `val_loss` does not
# decrease over a certain number of epochs.
wait += 1
if val_loss < best:
best = val_loss
wait = 0
if wait >= patience:
break
Start of epoch 0
Training loss at step 0: 2.3311
Seen so far: 128 samples
Training loss at step 200: 0.2537
Seen so far: 25728 samples
Training loss at step 400: 0.1951
Seen so far: 51328 samples
Training acc over epoch: 0.9314
Validation acc: 0.9633
Time taken: 2.06s

Start of epoch 1
Training loss at step 0: 0.0846
Seen so far: 128 samples
Training loss at step 200: 0.1519
Seen so far: 25728 samples
Training loss at step 400: 0.1162
Seen so far: 51328 samples
Training acc over epoch: 0.9691
Validation acc: 0.9686
Time taken: 1.06s

...

Start of epoch 27
Training loss at step 0: 0.0008
Seen so far: 128 samples
Training loss at step 200: 0.0919
Seen so far: 25728 samples
Training loss at step 400: 0.0024
Seen so far: 51328 samples
Training acc over epoch: 0.9956
Validation acc: 0.9726
Time taken: 1.04s
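If you use this rule in more than one loop, it can help to package the bookkeeping in a small helper. The class below is a sketch, not part of the original notebook; the name `EarlyStopper` and the `min_delta` argument are illustrative:

# A minimal sketch: the patience rule from the loop above, packaged so a
# custom training loop only calls should_stop(val_loss) once per epoch.
class EarlyStopper:
  def __init__(self, patience=5, min_delta=0.0):
    self.patience = patience
    self.min_delta = min_delta
    self.wait = 0
    self.best = float('inf')

  def should_stop(self, current_loss):
    # Reset the counter on a sufficiently large improvement; otherwise
    # count one more epoch without progress.
    if current_loss < self.best - self.min_delta:
      self.best = current_loss
      self.wait = 0
    else:
      self.wait += 1
    return self.wait >= self.patience

With this helper, the `wait`/`best` bookkeeping at the end of the epoch loop reduces to `if stopper.should_stop(float(val_loss)): break`.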
Next steps
- Learn more about the Keras built-in early stopping callback API in the API docs.
- Learn to write custom Keras callbacks, including early stopping at a minimum loss.
- Learn about Training and evaluation with the Keras built-in methods.
- Explore common regularization techniques in the Overfit and underfit tutorial, which uses the `EarlyStopping` callback.