Abstract base class used to build new callbacks.

Callbacks can be passed to Keras methods such as fit, evaluate, and predict to hook into the various stages of the model training and inference lifecycle.
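As a rough illustration of those lifecycle hooks, the sketch below passes one callback instance to fit, evaluate, and predict and records which begin/end hooks each call triggers. It assumes TensorFlow 2.x. The class name HookRecorder and the toy data are hypothetical. Because fit receives no validation data here, only the training hooks should fire during fit. evaluate should fire the test hooks, and predict should fire the predict hooks.

```python
import tensorflow as tf

class HookRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback: records the name of each begin/end hook Keras calls."""

    def __init__(self):
        super().__init__()
        self.calls = []

    # Hooks invoked by fit()
    def on_train_begin(self, logs=None):
        self.calls.append("on_train_begin")

    def on_train_end(self, logs=None):
        self.calls.append("on_train_end")

    # Hooks invoked by evaluate()
    def on_test_begin(self, logs=None):
        self.calls.append("on_test_begin")

    def on_test_end(self, logs=None):
        self.calls.append("on_test_end")

    # Hooks invoked by predict()
    def on_predict_begin(self, logs=None):
        self.calls.append("on_predict_begin")

    def on_predict_end(self, logs=None):
        self.calls.append("on_predict_end")

# Toy data: y = 2x
x = tf.constant([[1.0], [2.0]])
y = tf.constant([[2.0], [4.0]])

model = tf.keras.Sequential([tf.keras.Input(shape=(1,)), tf.keras.layers.Dense(1)])
model.compile(loss="mean_squared_error")

recorder = HookRecorder()
model.fit(x, y, epochs=1, verbose=0, callbacks=[recorder])
model.evaluate(x, y, verbose=0, callbacks=[recorder])
model.predict(x, verbose=0, callbacks=[recorder])
print(recorder.calls)
```

The same instance is reused across all three calls so the recorded list shows the order in which the stages ran.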

To create a custom callback, subclass keras.callbacks.Callback and override the method associated with the stage of interest. See the TensorFlow guide on writing your own callbacks for more information.


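import tensorflow as tf

# Flag flipped by the callback once training completes.
training_finished = False
class MyCallback(tf.keras.callbacks.Callback):
  def on_train_end(self, logs=None):
    # Called once, after the last epoch of fit() finishes.
    global training_finished
    training_finished = True
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(loss='mean_squared_error')
model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
          callbacks=[MyCallback()])
assert training_finished == True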
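A second minimal sketch, again assuming TensorFlow 2.x, overrides on_epoch_end instead of on_train_end. It reads the training loss from the logs dict that Keras passes to the hook, and it sets self.model.stop_training to end fit early. The class name LossLogger and its max_epochs parameter are hypothetical, chosen only for illustration.

```python
import tensorflow as tf

class LossLogger(tf.keras.callbacks.Callback):
    """Hypothetical callback: records per-epoch loss and stops after max_epochs."""

    def __init__(self, max_epochs):
        super().__init__()
        self.max_epochs = max_epochs
        self.losses = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # `logs` holds the metrics computed for the epoch that just ended.
        self.losses.append(float(logs["loss"]))
        # `epoch` is 0-based, so this requests a stop after `max_epochs` epochs.
        if epoch + 1 >= self.max_epochs:
            self.model.stop_training = True

model = tf.keras.Sequential([tf.keras.Input(shape=(1,)), tf.keras.layers.Dense(1)])
model.compile(loss="mean_squared_error")

logger = LossLogger(max_epochs=2)
# epochs=10 is an upper bound; the callback ends training after 2 epochs.
model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]), epochs=10, verbose=0,
          callbacks=[logger])
print(logger.losses)  # one loss value per completed epoch
```

Setting stop_training from a callback is the same mechanism the built-in EarlyStopping callback uses; fit checks the flag after each epoch ends.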

If you want to use Callback objects in a custom training loop:

  1. You should pack all your callbacks into a single