
tfa.optimizers.Triangular2CyclicalLearningRate


Class Triangular2CyclicalLearningRate

A LearningRateSchedule that uses a cyclical schedule.

Inherits From: CyclicalLearningRate

Aliases: tfa.optimizers.cyclical_learning_rate.Triangular2CyclicalLearningRate

__init__


__init__(
    initial_learning_rate,
    maximal_learning_rate,
    step_size,
    scale_mode='cycle',
    name='Triangular2CyclicalLearningRate'
)

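Applies the triangular2 cyclical schedule to the learning rate: the rate moves linearly between initial_learning_rate and maximal_learning_rate, and the height of each cycle above initial_learning_rate is half that of the previous cycle.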

See Cyclical Learning Rates for Training Neural Networks (https://arxiv.org/abs/1506.01186).
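
The schedule can be sketched in plain Python as follows. This is a minimal sketch, assuming scale_mode='cycle' and the triangular formula from the paper combined with the triangular2 scaling 1 / 2**(cycle - 1); the function name triangular2_lr is hypothetical and is not part of tfa.

```python
import math


def triangular2_lr(step, initial_learning_rate, maximal_learning_rate, step_size):
    """Hypothetical helper: learning rate at `step` under triangular2 (scale_mode='cycle')."""
    # Each cycle lasts 2 * step_size steps; cycles are numbered from 1.
    cycle = math.floor(1 + step / (2 * step_size))
    # x falls linearly from 1 to 0 over the first half of a cycle, then rises back to 1.
    x = abs(step / step_size - 2 * cycle + 1)
    # triangular2 halves the amplitude of every new cycle.
    scale = 1 / (2.0 ** (cycle - 1))
    return initial_learning_rate + (
        maximal_learning_rate - initial_learning_rate
    ) * max(0.0, 1 - x) * scale


# Peaks of the first three cycles for initial=1e-4, maximal=1e-2, step_size=2000.
for peak_step in (2000, 6000, 10000):
    print(peak_step, triangular2_lr(peak_step, 1e-4, 1e-2, 2000))
```

With step_size=2000 a full cycle is 4000 steps, so the peaks at steps 2000, 6000 and 10000 reach 1e-2, about 5.05e-3 and about 2.575e-3, while steps 0, 4000 and 8000 sit at initial_learning_rate.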

import tensorflow as tf
import tensorflow_addons as tfa

lr_schedule = tfa.optimizers.Triangular2CyclicalLearningRate(
    initial_learning_rate=1e-4,
    maximal_learning_rate=1e-2,
    step_size=2000,
    scale_mode="cycle",
    name="MyCyclicScheduler")

# model, data and labels are assumed to be defined elsewhere.
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(data, labels, epochs=5)

You can pass this schedule directly into a tf.keras.optimizers.Optimizer as the learning rate.

Args:

  • initial_learning_rate: A scalar float32 or float64 Tensor or a Python number. The initial learning rate.
  • maximal_learning_rate: A scalar float32 or float64 Tensor or a Python number. The maximum learning rate.
  • step_size: A scalar float32 or float64 Tensor or a Python number. The number of training steps in half a cycle; a full cycle lasts 2 * step_size steps.
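  • scale_fn: Not accepted by this class. The scaling function is fixed to 1 / 2**(x - 1), which halves the cycle amplitude each cycle.
  • scale_mode: One of 'cycle' or 'iterations'. Whether the scaling function is evaluated on the cycle number ('cycle') or on the training step ('iterations').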
  • name: (Optional) Name for the operation.
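
To see what scale_mode changes, the sketch below generalizes the formula with a pluggable scale function, in the style of the parent CyclicalLearningRate. It is an illustrative plain-Python sketch, not the tfa source: cyclical_lr and triangular2 are hypothetical names, and it assumes the library feeds the cycle number to the scaling function in 'cycle' mode and the raw step in 'iterations' mode.

```python
import math


def triangular2(x):
    # Halves the amplitude each time x grows by one; written as 2 ** (1 - x)
    # so that large x underflows to 0.0 instead of overflowing.
    return 2.0 ** (1 - x)


def cyclical_lr(step, initial_learning_rate, maximal_learning_rate, step_size,
                scale_fn, scale_mode="cycle"):
    """Hypothetical generic cyclical schedule with a pluggable scale_fn."""
    cycle = math.floor(1 + step / (2 * step_size))
    x = abs(step / step_size - 2 * cycle + 1)
    # 'cycle' passes the cycle number to scale_fn, 'iterations' passes the step.
    mode_step = cycle if scale_mode == "cycle" else step
    return initial_learning_rate + (
        maximal_learning_rate - initial_learning_rate
    ) * max(0.0, 1 - x) * scale_fn(mode_step)


# Learning rate at the first peak (step 2000) under each mode.
for mode in ("cycle", "iterations"):
    print(mode, cyclical_lr(2000, 1e-4, 1e-2, 2000, triangular2, mode))
```

Under these assumptions, combining the triangular2 scaling with 'iterations' drives the scale factor to essentially zero within a few dozen steps, so the learning rate stays pinned near initial_learning_rate; 'cycle' is the mode that produces the halving triangles.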

Returns:

Updated learning rate value.

Methods

__call__


__call__(step)

Returns the learning rate for the given training step.

from_config

from_config(
    cls,
    config
)

Instantiates a LearningRateSchedule from its config.

Args:

  • config: Output of get_config().

Returns:

A LearningRateSchedule instance.

get_config


get_config()
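
get_config and from_config are meant to round-trip: from_config takes the dictionary produced by get_config and rebuilds an equivalent schedule. The sketch below illustrates that contract with a hypothetical pure-Python stand-in class (Triangular2Schedule is not part of tfa, and the exact keys returned by tfa's get_config may differ). It assumes, as in the Keras base LearningRateSchedule, that from_config simply calls cls(**config), so the config should hold only constructor arguments.

```python
import math


class Triangular2Schedule:
    """Hypothetical stand-in illustrating the get_config/from_config contract."""

    def __init__(self, initial_learning_rate, maximal_learning_rate, step_size,
                 scale_mode="cycle", name="Triangular2CyclicalLearningRate"):
        self.initial_learning_rate = initial_learning_rate
        self.maximal_learning_rate = maximal_learning_rate
        self.step_size = step_size
        self.scale_mode = scale_mode
        self.name = name

    def __call__(self, step):
        # Learning rate for the given training step (triangular2 scaling).
        cycle = math.floor(1 + step / (2 * self.step_size))
        x = abs(step / self.step_size - 2 * cycle + 1)
        mode_step = cycle if self.scale_mode == "cycle" else step
        scale = 2.0 ** (1 - mode_step)
        return self.initial_learning_rate + (
            self.maximal_learning_rate - self.initial_learning_rate
        ) * max(0.0, 1 - x) * scale

    def get_config(self):
        # Only constructor arguments (no scale_fn, which __init__ does not accept).
        return {
            "initial_learning_rate": self.initial_learning_rate,
            "maximal_learning_rate": self.maximal_learning_rate,
            "step_size": self.step_size,
            "scale_mode": self.scale_mode,
            "name": self.name,
        }

    @classmethod
    def from_config(cls, config):
        # Rebuild the schedule from the dictionary returned by get_config().
        return cls(**config)


original = Triangular2Schedule(1e-4, 1e-2, 2000)
restored = Triangular2Schedule.from_config(original.get_config())
print(restored(6000) == original(6000))
```

Keeping the config limited to constructor arguments is what lets from_config stay a one-line cls(**config) call.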