tfa.optimizers.ExponentialCyclicalLearningRate

Class ExponentialCyclicalLearningRate

A LearningRateSchedule that uses a cyclical schedule with an exponentially scaled amplitude.

Inherits From: CyclicalLearningRate

__init__

__init__(
    initial_learning_rate,
    maximal_learning_rate,
    step_size,
    scale_mode='iterations',
    gamma=1.0,
    name='ExponentialCyclicalLearningRate'
)

Applies an exponential cyclical schedule to the learning rate.

See "Cyclical Learning Rates for Training Neural Networks": https://arxiv.org/abs/1506.01186
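As a reading aid, the schedule can be written as below. This assumes it follows the triangular policy of the paper, as implemented by the CyclicalLearningRate base class, with the amplitude multiplied by gamma ** m. The symbols are ours: t is the step passed to the schedule, s is step_size, and eta_init and eta_max are initial_learning_rate and maximal_learning_rate.

```latex
\begin{aligned}
\text{cycle} &= \left\lfloor 1 + \frac{t}{2s} \right\rfloor \\
x &= \left| \frac{t}{s} - 2\,\text{cycle} + 1 \right| \\
m &= \begin{cases} \text{cycle} & \text{if scale\_mode} = \text{'cycle'} \\ t & \text{if scale\_mode} = \text{'iterations'} \end{cases} \\
\eta_t &= \eta_{\text{init}} + (\eta_{\max} - \eta_{\text{init}}) \cdot \max(0,\, 1 - x) \cdot \gamma^{m}
\end{aligned}
```

Under this reading a full cycle lasts 2s steps: the rate starts at eta_init, peaks at t = (2k+1)s, and returns to eta_init at the end of the cycle, with each peak shrunk by the factor gamma ** m.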

import tensorflow as tf
import tensorflow_addons as tfa

lr_schedule = tfa.optimizers.ExponentialCyclicalLearningRate(
    initial_learning_rate=1e-4,
    maximal_learning_rate=1e-2,
    step_size=2000,
    scale_mode="cycle",
    gamma=0.96,
    name="MyCyclicScheduler")

# `model`, `data` and `labels` are placeholders for your own Keras model
# and training data.
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(data, labels, epochs=5)

You can pass this schedule directly into a tf.keras.optimizers.Optimizer as the learning rate.

Args:

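  • initial_learning_rate: A scalar float32 or float64 Tensor or a Python number. The initial learning rate, which is also the lower bound of each cycle.
  • maximal_learning_rate: A scalar float32 or float64 Tensor or a Python number. The maximum learning rate; the cycle amplitude is maximal_learning_rate - initial_learning_rate before the gamma scaling is applied.
  • step_size: A scalar float32 or float64 Tensor or a Python number. Number of training steps in half a cycle; a full cycle spans 2 * step_size steps.
  • scale_mode: Either 'cycle' or 'iterations' (the default). Selects whether the exponential scaling is driven by the cycle index or by the raw step count.
  • gamma: A scalar float32 or float64 Tensor or a Python number. Base of the exponential scaling: the cycle amplitude is multiplied by gamma ** x, where x is the cycle index ('cycle') or the step ('iterations'). This fixed scaling takes the place of the scale_fn argument accepted by CyclicalLearningRate, so scale_fn is not a parameter of this class. The default of 1.0 leaves the amplitude unscaled.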
  • name: (Optional) Name for the operation.
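Because gamma is raised to a much larger exponent in 'iterations' mode than in 'cycle' mode, the same gamma behaves very differently in the two modes. The sketch below computes the peak learning rate of each cycle in pure Python, assuming the peaks fall at step (2k + 1) * step_size and are scaled by gamma ** mode_step as in the base-class formula; the helper name cycle_peaks is hypothetical and not part of TFA.

```python
def cycle_peaks(initial_learning_rate, maximal_learning_rate, step_size,
                gamma, scale_mode, num_cycles):
    """Peak learning rate of each cycle under the exponential scaling.

    Sketch only: assumes the triangular factor equals 1 at step
    (2k + 1) * step_size and the amplitude is multiplied by
    gamma ** mode_step (mode_step = cycle index or raw step).
    """
    peaks = []
    for k in range(num_cycles):
        step = (2 * k + 1) * step_size  # middle of cycle k + 1
        cycle = k + 1                   # 1-based cycle index
        mode_step = cycle if scale_mode == "cycle" else step
        amplitude = (maximal_learning_rate - initial_learning_rate) * gamma ** mode_step
        peaks.append(initial_learning_rate + amplitude)
    return peaks


# gamma=0.96 shrinks each peak gently when applied once per cycle ...
print(cycle_peaks(1e-4, 1e-2, 2000, 0.96, "cycle", 3))
# ... but applied per step it wipes out the amplitude almost immediately.
print(cycle_peaks(1e-4, 1e-2, 2000, 0.96, "iterations", 3))
```

With the default 'iterations' mode, a useful gamma therefore has to sit much closer to 1.0 than the 0.96 used with 'cycle' in the example above.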

Returns:

Calling the schedule with a training step returns the updated learning rate value.

Methods

__call__

__call__(step)

Returns the learning rate for the given training step.
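A pure-Python sketch of what a call computes, assuming the formula of the CyclicalLearningRate base class with the scaling function fixed to gamma ** x. The function name exponential_cyclical_lr is hypothetical, and the code avoids TensorFlow so it runs standalone; it is an illustration, not the TFA implementation.

```python
import math


def exponential_cyclical_lr(step, initial_learning_rate, maximal_learning_rate,
                            step_size, gamma=1.0, scale_mode="iterations"):
    """Learning rate at `step` (sketch of the assumed schedule formula)."""
    # 1-based index of the current cycle; one cycle spans 2 * step_size steps.
    cycle = math.floor(1 + step / (2 * step_size))
    # 0 at the cycle's peak, 1 at its start and end.
    x = abs(step / step_size - 2 * cycle + 1)
    # The exponential scaling uses either the cycle index or the raw step.
    mode_step = cycle if scale_mode == "cycle" else step
    amplitude = (maximal_learning_rate - initial_learning_rate) * max(0.0, 1 - x)
    return initial_learning_rate + amplitude * gamma ** mode_step


# Trace the first cycle and a half with the settings from the example above.
for step in (0, 1000, 2000, 3000, 4000, 6000):
    lr = exponential_cyclical_lr(step, 1e-4, 1e-2, 2000,
                                 gamma=0.96, scale_mode="cycle")
    print(step, lr)
```

The rate climbs linearly from initial_learning_rate to a peak at step 2000, falls back by step 4000, and the second peak at step 6000 is lower because it is scaled by 0.96 ** 2 instead of 0.96.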

from_config

@classmethod
from_config(
    cls,
    config
)

Instantiates a LearningRateSchedule from its config.

Args:

  • config: Output of get_config().

Returns:

A LearningRateSchedule instance.
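The get_config()/from_config() pair lets a schedule be turned into a plain dictionary (for example, when Keras serializes an optimizer that uses it) and rebuilt later. Below is a minimal sketch of that round trip using a hypothetical ToyCyclicalSchedule stand-in so the code runs without TensorFlow. The default from_config of the Keras LearningRateSchedule base class likewise rebuilds the object with cls(**config), but the exact keys returned by the real TFA class are an assumption here.

```python
class ToyCyclicalSchedule:
    """Hypothetical stand-in mimicking the get_config/from_config contract.

    NOT tfa.optimizers.ExponentialCyclicalLearningRate: it only stores the
    constructor arguments so the round trip can run without TensorFlow.
    """

    def __init__(self, initial_learning_rate, maximal_learning_rate, step_size,
                 scale_mode="iterations", gamma=1.0,
                 name="ExponentialCyclicalLearningRate"):
        self.initial_learning_rate = initial_learning_rate
        self.maximal_learning_rate = maximal_learning_rate
        self.step_size = step_size
        self.scale_mode = scale_mode
        self.gamma = gamma
        self.name = name

    def get_config(self):
        # Plain dict whose keys match the constructor arguments
        # (assumed key set; the real class may differ).
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "maximal_learning_rate": self.maximal_learning_rate,
            "step_size": self.step_size,
            "scale_mode": self.scale_mode,
            "gamma": self.gamma,
            "name": self.name,
        }

    @classmethod
    def from_config(cls, config):
        # Rebuild the schedule by passing the config back as keyword arguments.
        return cls(**config)


original = ToyCyclicalSchedule(1e-4, 1e-2, 2000, scale_mode="cycle", gamma=0.96)
config = original.get_config()
restored = ToyCyclicalSchedule.from_config(config)
print(restored.get_config() == config)
```

Keeping get_config() limited to constructor arguments is what makes cls(**config) a sufficient from_config().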

get_config

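get_config()

Returns the configuration of the schedule as a Python dictionary; passing it to from_config() recreates the schedule.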