tf.keras.callbacks.LearningRateScheduler

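Learning rate scheduler. At the beginning of every epoch, this callback calls the schedule function with the current epoch index and sets the optimizer's learning rate to the value it returns.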

Inherits From: Callback

tf.keras.callbacks.LearningRateScheduler(
    schedule, verbose=0
)
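
Conceptually, the callback does one thing at the start of each epoch: it asks the schedule for a learning rate and writes that value into the optimizer. The pure-Python sketch below models this loop under simplifying assumptions. It is not the Keras implementation: SimpleOptimizer and run_epochs are hypothetical names, training itself is omitted, the real callback also validates the returned value, and the printed message only approximates what verbose=1 logs.

```python
from typing import Callable, List


class SimpleOptimizer:
    """Hypothetical stand-in for a Keras optimizer; it only stores a learning rate."""

    def __init__(self, lr: float) -> None:
        self.lr = lr


def run_epochs(schedule: Callable[[int], float],
               optimizer: SimpleOptimizer,
               epochs: int,
               verbose: int = 0) -> List[float]:
    """Simplified model of LearningRateScheduler driving an optimizer.

    Before each epoch (indexed from 0), the rate returned by schedule(epoch)
    is written into optimizer.lr. Returns the rate used for every epoch.
    """
    history: List[float] = []
    for epoch in range(epochs):
        optimizer.lr = float(schedule(epoch))
        if verbose > 0:
            # Illustrative message only; not Keras's exact log line.
            print(f"Epoch {epoch + 1}: learning rate set to {optimizer.lr}")
        history.append(optimizer.lr)
        # ... one epoch of training with optimizer.lr would run here ...
    return history
```

For example, run_epochs(lambda epoch: 0.1 * 0.5 ** epoch, SimpleOptimizer(1.0), epochs=3) returns the rates used for epochs 0, 1 and 2 (0.1, 0.05 and 0.025), and the optimizer is left holding the last one. The initial rate of the optimizer is simply overwritten, which mirrors the fact that the schedule, not the optimizer's constructor argument, decides the rate once the callback is active.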

Arguments:

  • schedule: a function that takes an epoch index as input (integer, indexed from 0) and returns a new learning rate as output (float).
  • verbose: int. 0: quiet, 1: update messages (the new learning rate is printed at the start of each epoch).
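
Any Python callable that maps an integer epoch index to a float satisfies the schedule contract. As one illustration, here is step decay, a common technique that this page does not itself prescribe; step_decay and its constants (initial rate 0.01, halved every 10 epochs) are hypothetical choices.

```python
def step_decay(epoch: int) -> float:
    """Hypothetical step-decay schedule: start at 0.01 and halve every 10 epochs."""
    initial_lr = 0.01
    drop = 0.5            # multiply the rate by this factor at each step
    epochs_per_drop = 10  # number of epochs between steps
    return initial_lr * drop ** (epoch // epochs_per_drop)
```

It would be passed as tf.keras.callbacks.LearningRateScheduler(step_decay). Epochs 0 to 9 get 0.01, epochs 10 to 19 get 0.005, and so on. Returning a plain Python float keeps the function free of TensorFlow ops and easy to test on its own.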

Example:

import tensorflow as tf

# This function keeps the learning rate at 0.001 for the first ten epochs
# and decreases it exponentially after that.
def scheduler(epoch):
  if epoch < 10:
    return 0.001
  else:
    return 0.001 * tf.math.exp(0.1 * (10 - epoch))

callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
# `model` is assumed to be a compiled Keras model; `data`, `labels`,
# `val_data` and `val_labels` are assumed to be its training and
# validation arrays.
model.fit(data, labels, epochs=100, callbacks=[callback],
          validation_data=(val_data, val_labels))
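
To see the rates this example actually produces, the same rule can be evaluated with Python's math.exp in place of tf.math.exp. The swap is made only so the sketch runs without TensorFlow, and example_lr is a hypothetical name.

```python
import math


def example_lr(epoch: int) -> float:
    """Same rule as scheduler() in the example, but computed with math.exp."""
    if epoch < 10:
        return 0.001
    # From epoch 10 on, each epoch multiplies the rate by exp(-0.1), about 0.905.
    return 0.001 * math.exp(0.1 * (10 - epoch))


for epoch in (0, 9, 10, 11, 20, 99):
    print(f"epoch {epoch:3d}: lr = {example_lr(epoch):.3e}")
```

The rate is continuous at epoch 10, since exp(0) = 1, and by epoch 20 it has fallen to 0.001 * e^-1, about 3.7e-4.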

Methods

set_model

set_model(
    model
)

set_params

set_params(
    params
)