This class implements the decoupled weight decay described by Loshchilov & Hutter
(https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is
decoupled from the optimization steps with respect to the loss function.
For SGD variants, this simplifies hyperparameter search since it decouples
the settings of weight decay and learning rate.
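A minimal pure-Python sketch of that difference for plain SGD. It assumes the decoupled decay is applied as `wd * w` per step with no learning-rate factor (as in the weight_decay argument of this class); `sgd_l2_step` and `sgdw_step` are hypothetical helpers, not library functions.

```python
def sgd_l2_step(w, grad, lr, wd):
    """Plain SGD on loss + (wd / 2) * ||w||^2: the decay is folded into the gradient."""
    # Effective per-step shrinkage is lr * wd, so it changes whenever lr changes.
    return [wi - lr * (gi + wd * wi) for wi, gi in zip(w, grad)]


def sgdw_step(w, grad, lr, wd):
    """SGD with decoupled weight decay: gradient step and decay are separate terms."""
    # Assumption: the shrinkage wd * w is not multiplied by lr.
    return [wi - lr * gi - wd * wi for wi, gi in zip(w, grad)]


# With a zero loss gradient only the decay acts; compare two learning rates.
for lr in (0.1, 0.01):
    print(lr, sgd_l2_step([1.0], [0.0], lr, 0.01), sgdw_step([1.0], [0.0], lr, 0.01))
```

Under L2 regularization, cutting lr tenfold also weakens the decay tenfold, so the two settings must be retuned together; with the decoupled update the decay strength stays the same.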
For adaptive gradient algorithms, it regularizes variables with large
gradients more than L2 regularization would, which was shown to yield
better training loss and generalization error in the paper above.
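To make the adaptive case concrete, here is a minimal sketch of a scalar Adam update with either coupled (L2) or decoupled decay. Assumptions: the decoupled variant subtracts `wd * w` directly, without a learning-rate factor, and the toy loss gradient is a zero-mean signal alternating between `+grad_scale` and `-grad_scale`. `adam_step`, `run` and `decay_effect` are hypothetical helpers.

```python
import math


def adam_step(w, g, m, v, t, lr, wd, decoupled,
              beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar weight w; returns (w, m, v)."""
    if not decoupled:
        # L2: the decay term joins the gradient, so it is divided by sqrt(v_hat).
        g = g + wd * w
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    w_new = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    if decoupled:
        # Decoupled decay: shrink w directly, bypassing the adaptive scaling.
        w_new -= wd * w
    return w_new, m, v


def run(grad_scale, wd, decoupled, steps=200, lr=0.01):
    """Train one weight on a gradient alternating between +/- grad_scale."""
    w, m, v = 1.0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_scale if t % 2 else -grad_scale
        w, m, v = adam_step(w, g, m, v, t, lr, wd, decoupled)
    return w


def decay_effect(grad_scale, decoupled, wd=1e-3):
    """How much weight decay shrank w compared with a run without decay."""
    return run(grad_scale, 0.0, decoupled) - run(grad_scale, wd, decoupled)


for scale in (1.0, 100.0):
    print(scale, decay_effect(scale, decoupled=False), decay_effect(scale, decoupled=True))
```

In this toy setting the L2 decay effect shrinks roughly in proportion to the gradient scale, because it is normalized by sqrt(v_hat), while the decoupled decay is the same for both scales.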
This class alone is not an optimizer but rather extends existing
optimizers with decoupled weight decay. We explicitly define the two
examples used in the above paper (SGDW and AdamW), but in general this can
extend any OptimizerX class by using:

    ExtendedCls = extend_with_decoupled_weight_decay(OptimizerX)

Weight decay can then be set when instantiating the optimizer:

    optimizerX = ExtendedCls(weight_decay=0.001, learning_rate=0.001)
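The pattern can be sketched in plain Python with a toy optimizer and a mixin. Everything below is hypothetical: `SGD`, `DecoupledWeightDecayMixin` and this `extend_with_decoupled_weight_decay` only mimic the idea (shrink each variable by `weight_decay * var`, then run the base update) and are not the TensorFlow Addons implementation.

```python
class SGD:
    """Toy optimizer standing in for an existing OptimizerX (hypothetical)."""

    def __init__(self, learning_rate=0.01):
        self.learning_rate = learning_rate

    def apply_gradients(self, grads_and_vars):
        # Variables are dicts {"value": float} so they can be updated in place.
        for grad, var in grads_and_vars:
            var["value"] -= self.learning_rate * grad


class DecoupledWeightDecayMixin:
    """Hypothetical mixin: shrink each variable, then run the base update."""

    def __init__(self, weight_decay, **kwargs):
        self.weight_decay = weight_decay
        super().__init__(**kwargs)  # remaining kwargs go to the base optimizer

    def apply_gradients(self, grads_and_vars):
        grads_and_vars = list(grads_and_vars)
        for _, var in grads_and_vars:
            var["value"] -= self.weight_decay * var["value"]
        super().apply_gradients(grads_and_vars)


def extend_with_decoupled_weight_decay(base_optimizer):
    """Toy factory: a subclass with the mixin listed first among its bases."""
    return type("Extended" + base_optimizer.__name__,
                (DecoupledWeightDecayMixin, base_optimizer), {})


ExtendedCls = extend_with_decoupled_weight_decay(SGD)
optimizer = ExtendedCls(weight_decay=0.1, learning_rate=0.5)
w = {"value": 1.0}
optimizer.apply_gradients([(0.2, w)])
print(ExtendedCls.__name__, w["value"])  # (1 - 0.1) * 1.0 - 0.5 * 0.2, about 0.8
```

Building the class with `type()` keeps the factory generic: any optimizer class with a compatible `apply_gradients` can be passed in.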
For this to work, this class must be the first class that the optimizer
with weight decay inherits from.
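A minimal sketch of why the order matters, using Python's method resolution order (MRO) with hypothetical toy classes: when the extension comes first, its `__init__` consumes `weight_decay` and its `apply_gradients` wraps the base optimizer's; when it comes second, the base class's methods are found first.

```python
class BaseOptimizer:
    """Hypothetical stand-in for an existing optimizer class."""

    def __init__(self, learning_rate=0.01):
        self.learning_rate = learning_rate

    def apply_gradients(self, grads_and_vars):
        return "base update only"


class WeightDecayExtension:
    """Hypothetical extension that has to come first among the bases."""

    def __init__(self, weight_decay, **kwargs):
        self.weight_decay = weight_decay
        super().__init__(**kwargs)  # forwards learning_rate etc. to the base

    def apply_gradients(self, grads_and_vars):
        return "decay, then " + super().apply_gradients(grads_and_vars)


class GoodW(WeightDecayExtension, BaseOptimizer):
    pass


class BadW(BaseOptimizer, WeightDecayExtension):
    pass


print([c.__name__ for c in GoodW.__mro__])
print(GoodW(weight_decay=1e-4, learning_rate=1e-3).apply_gradients([]))
try:
    BadW(weight_decay=1e-4, learning_rate=1e-3)
except TypeError as err:
    # BaseOptimizer.__init__ is found first and rejects weight_decay.
    print("BadW failed:", err)
# Even with a permissive __init__, the decay step would be skipped:
print(BadW.apply_gradients is BaseOptimizer.apply_gradients)
```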
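When applying a decay to the learning rate, be sure to manually apply the
same decay to weight_decay as well, for example:

    import tensorflow as tf
    import tensorflow_addons as tfa

    step = tf.Variable(0, trainable=False)
    schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        [10000, 15000], [1e-0, 1e-1, 1e-2])
    # lr and wd can be a function or a tensor; in eager mode a tensor computed
    # here keeps its initial value, so use functions to follow `step`
    lr = lambda: 1e-1 * schedule(step)
    wd = lambda: 1e-4 * schedule(step)
    optimizer = tfa.optimizers.AdamW(learning_rate=lr, weight_decay=wd)
    # `step` is not advanced automatically: call step.assign_add(1) per update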
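A pure-Python sketch of how both hyperparameters move together under the schedule above. It assumes PiecewiseConstantDecay returns `values[i]` while `step <= boundaries[i]`; `piecewise_constant`, `lr_at` and `wd_at` are hypothetical helpers, not the TensorFlow API.

```python
import bisect

BOUNDARIES = [10000, 15000]
VALUES = [1e-0, 1e-1, 1e-2]


def piecewise_constant(step, boundaries, values):
    """Assumed PiecewiseConstantDecay semantics: values[i] while step <= boundaries[i]."""
    # bisect_left returns how many boundaries are strictly smaller than step.
    return values[bisect.bisect_left(boundaries, step)]


def lr_at(step):
    return 1e-1 * piecewise_constant(step, BOUNDARIES, VALUES)


def wd_at(step):
    return 1e-4 * piecewise_constant(step, BOUNDARIES, VALUES)


for step in (0, 12000, 20000):
    # Both follow the same schedule, so wd / lr stays at 1e-3 in every phase.
    print(step, lr_at(step), wd_at(step), wd_at(step) / lr_at(step))
```

If only the learning rate followed the schedule, the decay would keep its initial strength while the gradient steps shrank, so weight decay would increasingly dominate later in training.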
The minimize() method simply computes the gradients using tf.GradientTape
and calls apply_gradients(). If you want to process the gradients before
applying them, call tf.GradientTape and apply_gradients() explicitly
instead of using this method.
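The two paths can be contrasted with a scalar toy in plain Python. `loss_grad` is an analytic gradient standing in for tf.GradientTape, and `clip`, `apply_gradient` and `minimize` are hypothetical functions; clipping is just one example of processing a gradient before it is applied.

```python
def loss_grad(w):
    """Analytic gradient of the toy loss (w - 3)**2; stands in for tf.GradientTape."""
    return 2.0 * (w - 3.0)


def clip(grad, limit):
    """One example of gradient processing: clip to [-limit, limit]."""
    return max(-limit, min(limit, grad))


def apply_gradient(w, grad, lr=0.1, wd=0.0):
    """Toy SGDW-style update: decoupled decay plus a plain gradient step."""
    return w - wd * w - lr * grad


def minimize(w, lr=0.1, wd=0.0):
    """One-call path: compute the gradient and apply it unchanged."""
    return apply_gradient(w, loss_grad(w), lr, wd)


w = 10.0
w_minimize = minimize(w)                                 # gradient 14.0 applied as is
w_explicit = apply_gradient(w, clip(loss_grad(w), 1.0))  # gradient clipped to 1.0
print(w_minimize, w_explicit)
```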
Args:
    loss: A callable taking no arguments which returns the value to
        minimize.
    var_list: list or tuple of Variable objects to update to
        minimize loss, or a callable returning the list or tuple of
        Variable objects. Use a callable when the variable list would
        otherwise be incomplete before minimize is called, because the
        variables are created the first time loss is called.
    grad_loss: Optional. A Tensor holding the gradient computed for
        loss.
    decay_var_list: Optional list of variables to be decayed. Defaults
        to all variables in var_list.
    name: Optional name for the returned operation.
Returns:
    An Operation that updates the variables in var_list.
Raises:
    ValueError: If some of the variables are not Variable objects.
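How a list of variables to decay narrows the decay step can be sketched with plain dictionaries. `apply_with_decay` is a hypothetical helper that identifies variables by name rather than by Variable object, and excluding a bias from decay is only an illustrative use.

```python
def apply_with_decay(variables, grads, lr, wd, decay_var_list=None):
    """Toy update over named scalar variables.

    variables, grads: dicts mapping a variable name to a float.
    decay_var_list: names to decay; None means every variable in var_list.
    """
    var_list = list(variables)
    decayed = set(var_list if decay_var_list is None else decay_var_list)
    updated = {}
    for name in var_list:
        w = variables[name]
        if name in decayed:
            w -= wd * w  # decoupled decay only for the selected variables
        updated[name] = w - lr * grads[name]
    return updated


variables = {"kernel": 1.0, "bias": 1.0}
grads = {"kernel": 0.0, "bias": 0.0}
print(apply_with_decay(variables, grads, lr=0.1, wd=0.5))
print(apply_with_decay(variables, grads, lr=0.1, wd=0.5, decay_var_list=["kernel"]))
```

With the default, both variables shrink by half; restricting the list leaves the bias untouched while the kernel is still decayed.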