
tfm.optimization.OptimizerFactory


Optimizer factory class.

This class builds the learning rate and the optimizer from an optimization config. To use it: (1) define an optimization config, which specifies the optimizer and the learning rate schedule; (2) initialize the class with that config; (3) build the learning rate; (4) build the optimizer.

A typical example of using this class:

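import tensorflow_models as tfm

# Each section selects an implementation via 'type' and configures it
# under the key of the same name.
params = {
    'optimizer': {
        'type': 'sgd',
        'sgd': {'momentum': 0.9}
    },
    'learning_rate': {
        'type': 'stepwise',
        'stepwise': {'boundaries': [10000, 20000],
                     'values': [0.1, 0.01, 0.001]}
    },
    'warmup': {
        'type': 'linear',
        'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
    }
}
opt_config = tfm.optimization.OptimizationConfig(params)
opt_factory = tfm.optimization.OptimizerFactory(opt_config)
lr = opt_factory.build_learning_rate()       # schedule (or float if constant)
optimizer = opt_factory.build_optimizer(lr)  # optimizer using that lr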
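To make the example config concrete, the learning rate it describes can be approximated in plain Python. This is a sketch under two assumptions: the stepwise schedule follows tf.keras.optimizers.schedules.PiecewiseConstantDecay semantics (a step equal to a boundary still uses the earlier value), and linear warmup ramps from warmup_learning_rate to the schedule's value at warmup_steps. The function name stepwise_with_linear_warmup is hypothetical and not part of tfm.

```python
def stepwise_with_linear_warmup(step, boundaries, values,
                                warmup_steps, warmup_learning_rate):
    """Hypothetical approximation of the example config's learning rate."""

    def stepwise(s):
        # Piecewise constant: values[i] applies while s <= boundaries[i];
        # past the last boundary, the final value applies.
        for boundary, value in zip(boundaries, values):
            if s <= boundary:
                return value
        return values[-1]

    if step < warmup_steps:
        # Assumed linear ramp toward the schedule's value at warmup_steps.
        target = stepwise(warmup_steps)
        return warmup_learning_rate + (step / warmup_steps) * (
            target - warmup_learning_rate)
    return stepwise(step)


config = dict(boundaries=[10000, 20000], values=[0.1, 0.01, 0.001],
              warmup_steps=500, warmup_learning_rate=0.01)
for step in (0, 250, 500, 10000, 10001, 20001):
    print(step, round(stepwise_with_linear_warmup(step, **config), 4))
```

Under these assumptions the rate climbs from 0.01 to 0.1 over the first 500 steps, then drops to 0.01 after step 10000 and to 0.001 after step 20000.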

Args
config An OptimizationConfig instance containing the optimization config.

Methods

build_learning_rate


Build learning rate.

Builds the learning rate from the config. The learning rate schedule is built according to the learning rate config. If the learning rate type is constant, lr_config.learning_rate is returned.

Returns
A tf.keras.optimizers.schedules.LearningRateSchedule instance. If the learning rate type is constant, lr_config.learning_rate is returned.
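The two possible return shapes can be illustrated with a small dispatch sketch. Here build_learning_rate_sketch is a hypothetical stand-in, not the tfm implementation: a plain callable plays the role of a LearningRateSchedule, only the constant and stepwise types are handled, warmup is omitted, and the dict layout (a 'type' key plus a same-named options key holding 'learning_rate' for the constant case) is an assumption modeled on the example config.

```python
from typing import Callable, Union

# A learning rate is either a plain float or a step -> rate callable.
LearningRate = Union[float, Callable[[int], float]]


def build_learning_rate_sketch(lr_config: dict) -> LearningRate:
    """Hypothetical stand-in for OptimizerFactory.build_learning_rate."""
    lr_type = lr_config["type"]
    options = lr_config.get(lr_type, {})
    if lr_type == "constant":
        # Constant type: the configured value itself is returned, not a schedule.
        return float(options["learning_rate"])
    if lr_type == "stepwise":
        boundaries, values = options["boundaries"], options["values"]

        def schedule(step: int) -> float:
            # Piecewise constant lookup standing in for a LearningRateSchedule.
            for boundary, value in zip(boundaries, values):
                if step <= boundary:
                    return value
            return values[-1]

        return schedule
    raise ValueError(f"Unsupported learning rate type in this sketch: {lr_type}")


const_lr = build_learning_rate_sketch(
    {"type": "constant", "constant": {"learning_rate": 0.1}})
step_lr = build_learning_rate_sketch(
    {"type": "stepwise",
     "stepwise": {"boundaries": [10000, 20000], "values": [0.1, 0.01, 0.001]}})
print(const_lr, step_lr(15000))
```

Callers that accept the result should therefore be prepared for either a number or a step-dependent schedule.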

build_optimizer


Build optimizer.

Builds the optimizer from the config. It takes the learning rate as input and builds the optimizer according to the optimizer config. Typically, the learning rate built with self.build_learning_rate() is passed to this method.
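Because the learning rate passed in may be either a float or a schedule, the optimizer has to resolve it at each step. The following is a minimal pure-Python sketch of that idea, not the Keras implementation: MomentumSGDSketch is a hypothetical class, weights and gradients are plain lists of floats, and the update rule assumes classical (non-Nesterov) momentum of the form used by Keras SGD.

```python
from typing import Callable, List, Union

LearningRate = Union[float, Callable[[int], float]]


class MomentumSGDSketch:
    """Hypothetical optimizer that accepts a float or a schedule."""

    def __init__(self, learning_rate: LearningRate, momentum: float = 0.9):
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.iterations = 0
        self.velocity: List[float] = []

    def current_lr(self) -> float:
        # A schedule is evaluated at the current step; a float is used as-is.
        if callable(self.learning_rate):
            return self.learning_rate(self.iterations)
        return self.learning_rate

    def apply_gradients(self, grads: List[float],
                        weights: List[float]) -> List[float]:
        if not self.velocity:
            self.velocity = [0.0] * len(weights)
        lr = self.current_lr()
        new_weights = []
        for i, (g, w) in enumerate(zip(grads, weights)):
            # Classical momentum: v <- m * v - lr * g, then w <- w + v.
            self.velocity[i] = self.momentum * self.velocity[i] - lr * g
            new_weights.append(w + self.velocity[i])
        self.iterations += 1
        return new_weights


# Same optimizer logic, driven once by a float and once by a schedule.
opt = MomentumSGDSketch(0.1)
w = [0.0]
for _ in range(2):
    w = opt.apply_gradients([1.0], w)

opt2 = MomentumSGDSketch(lambda step: 0.1 if step < 1 else 0.01)
w2 = [0.0]
for _ in range(2):
    w2 = opt2.apply_gradients([1.0], w2)
print(w, w2)
```

Checking callable() at each step is one simple way to support both input kinds; it keeps the optimizer agnostic to whether build_learning_rate returned a constant.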

Args
lr A floating point value, or a tf.keras.optimizers.schedules.LearningRateSchedule instance.
gradient_aggregator Optional function to override gradient aggregation.
gradient_transformers Optional list of functions used to transform gradients before updates are applied to Variables. The functions are applied after gradient_aggregator and must accept and return a list of (gradient, variable) tuples. clipvalue, clipnorm, and global_clipnorm should not be set when gradient_transformers is passed.
postprocessor An optional function for postprocessing the optimizer. It takes an optimizer and returns an optimizer.
use_legacy_optimizer A boolean indicating whether to use the legacy optimizers.
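The gradient_transformers contract (a list of (gradient, variable) tuples in, the same structure out, applied in order) can be illustrated in plain Python. In this sketch gradients are lists of floats and variables are just names; clip_by_global_norm_transformer and apply_transformers are hypothetical helpers, not tfm or Keras functions. Since clipnorm and global_clipnorm should not be set alongside gradient_transformers, a transformer like this is where clipping would go instead.

```python
import math
from typing import Callable, List, Sequence, Tuple

# (gradient, variable) pair; here a gradient is a list of floats and the
# "variable" is only a name, standing in for a tensor and a tf.Variable.
GradVar = Tuple[List[float], str]
Transformer = Callable[[Sequence[GradVar]], List[GradVar]]


def clip_by_global_norm_transformer(clip_norm: float) -> Transformer:
    """Build a hypothetical transformer that clips by global norm."""

    def transform(grads_and_vars: Sequence[GradVar]) -> List[GradVar]:
        global_norm = math.sqrt(
            sum(x * x for g, _ in grads_and_vars for x in g))
        scale = min(1.0, clip_norm / global_norm) if global_norm > 0 else 1.0
        # Keep each (gradient, variable) pairing; only the gradient is scaled.
        return [([x * scale for x in g], v) for g, v in grads_and_vars]

    return transform


def apply_transformers(grads_and_vars: Sequence[GradVar],
                       transformers: Sequence[Transformer]) -> List[GradVar]:
    # Each transformer receives the previous transformer's output.
    result = list(grads_and_vars)
    for transform in transformers:
        result = transform(result)
    return result


pairs = [([3.0], "w1"), ([4.0], "w2")]
clipped = apply_transformers(pairs, [clip_by_global_norm_transformer(1.0)])
print(clipped)
```

The global norm of the example gradients is 5, so clipping to 1.0 scales every gradient by 0.2 while leaving the variable pairing untouched.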

Returns
A tf.keras.optimizers.legacy.Optimizer or tf.keras.optimizers.experimental.Optimizer instance.