The KFAC Optimizer (https://arxiv.org/abs/1503.05671).
__init__( learning_rate, cov_ema_decay, damping, layer_collection, var_list=None, momentum=0.9, momentum_type='regular', norm_constraint=None, name='KFAC', estimation_mode='gradients', colocate_gradients_with_ops=True, batch_size=None, placement_strategy=None, **kwargs )
Initializes the KFAC optimizer with the given settings.
learning_rate: The base learning rate for the optimizer. Should probably be set to 1.0 when using momentum_type = 'qmodel', but can still be set lower if desired (effectively lowering the trust in the quadratic model).
cov_ema_decay: The decay factor used when calculating the covariance estimate moving averages.
damping: The damping factor used to stabilize training due to errors in the local approximation with the Fisher information matrix, and to regularize the update direction by making it closer to the gradient. If damping is adapted during training, this value is used to initialize the damping variable. (Higher damping means the update looks more like a standard gradient update; see Tikhonov regularization.)
layer_collection: The layer collection object, which holds the fisher blocks, kronecker factors, and losses associated with the graph. The layer_collection cannot be modified after KfacOptimizer's initialization.
var_list: Optional list or tuple of variables to train. Defaults to the list of variables collected in the graph under the key
momentum: The momentum decay constant to use. Only applies when momentum_type is 'regular' or 'adam'. (Default: 0.9)
momentum_type: The type of momentum to use in this optimizer, one of 'regular', 'adam', or 'qmodel'. (Default: 'regular')
norm_constraint: float or Tensor. If specified, the update is scaled down so that its approximate squared Fisher norm v^T F v is at most the specified value. May only be used with momentum type 'regular'. (Default: None)
name: The name for this optimizer. (Default: 'KFAC')
estimation_mode: The type of estimator to use for the Fishers. Can be 'gradients', 'empirical', 'curvature_propagation', or 'exact'. (Default: 'gradients'). See the doc-string for FisherEstimator for a more detailed description of these options.
colocate_gradients_with_ops: Whether we should request gradients we compute in the estimator be colocated with their respective ops. (Default: True)
batch_size: The size of the mini-batch. Only needed when momentum_type == 'qmodel' or when automatic adjustment is used. (Default: None)
placement_strategy: string, Device placement strategy used when creating covariance variables, covariance ops, and inverse ops. (Default:
**kwargs: Arguments to be passed to the specific placement strategy mixin. Check
ValueError: If the momentum type is unsupported.
ValueError: If clipping is used with momentum type other than 'regular'.
ValueError: If no losses have been registered with layer_collection.
ValueError: If momentum is non-zero and momentum_type is not 'regular' or 'adam'.
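The effect of damping and norm_constraint described above can be illustrated with a small, self-contained sketch. It does not use the KFAC library: the 2x2 matrix standing in for the Fisher, the gradient, and the helper names (damped_update, fisher_sq_norm, clip_to_fisher_norm) are all hypothetical, and the clipping rule (scale by min(1, sqrt(c / v^T F v))) is an assumption chosen to satisfy the stated bound, not necessarily what the optimizer does internally.

```python
import math


def damped_update(fisher, grad, damping):
    """Solve (F + damping * I) v = grad for a 2x2 matrix F (closed-form inverse)."""
    a = fisher[0][0] + damping
    b = fisher[0][1]
    c = fisher[1][0]
    d = fisher[1][1] + damping
    det = a * d - b * c
    return [(d * grad[0] - b * grad[1]) / det,
            (-c * grad[0] + a * grad[1]) / det]


def cosine(u, v):
    """Cosine of the angle between two 2-vectors."""
    dot = sum(x * y for x, y in zip(u, v))
    return dot / (math.hypot(*u) * math.hypot(*v))


def fisher_sq_norm(update, fisher):
    """Squared Fisher norm v^T F v."""
    fv = [sum(f * x for f, x in zip(row, update)) for row in fisher]
    return sum(x * y for x, y in zip(update, fv))


def clip_to_fisher_norm(update, fisher, norm_constraint):
    """Scale the update down (never up) so that v^T F v <= norm_constraint."""
    scale = min(1.0, math.sqrt(norm_constraint / fisher_sq_norm(update, fisher)))
    return [scale * x for x in update]


# Toy symmetric positive-definite "Fisher" and gradient (illustrative values).
fisher = [[4.0, 1.0], [1.0, 0.5]]
grad = [1.0, 1.0]

dampings = [1e-3, 1e-1, 10.0]
cosines = [cosine(damped_update(fisher, grad, lam), grad) for lam in dampings]
for lam, cos in zip(dampings, cosines):
    print("damping=%g  cosine(update, grad)=%.4f" % (lam, cos))

# Clip the weakly damped update to a squared Fisher norm of at most 0.5.
clipped = clip_to_fisher_norm(damped_update(fisher, grad, 1e-3), fisher, 0.5)
print("clipped squared Fisher norm:", fisher_sq_norm(clipped, fisher))
```

In this toy setting the cosine between the update and the gradient grows as damping increases, which is the sense in which a heavily damped update "looks more like a standard gradient update".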
apply_gradients( grads_and_vars, *args, **kwargs )
Applies gradients to variables.
grads_and_vars: List of (gradient, variable) pairs.
*args: Additional arguments for super.apply_gradients.
**kwargs: Additional keyword arguments for super.apply_gradients.
Operation that applies the specified gradients.
compute_gradients( *args, **kwargs )
Create thunks that make the ops and vars on demand.
This function returns 4 lists of thunks: cov_variable_thunks, cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
The length of each list is the number of factors and the i-th element of each list corresponds to the i-th factor (given by the "factors" property).
Note that the execution of these thunks must happen in a certain partial order. The i-th element of cov_variable_thunks must execute before the i-th element of cov_update_thunks (and also the i-th element of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks must execute before the i-th element of inv_update_thunks.
TL;DR (oversimplified): Execute the thunks according to the order that they are returned.
cov_variable_thunks: A list of thunks that make the cov variables.
cov_update_thunks: A list of thunks that make the cov update ops.
inv_variable_thunks: A list of thunks that make the inv variables.
inv_update_thunks: A list of thunks that make the inv update ops.
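The ordering requirement on the four thunk lists can be sketched without TensorFlow. In the toy below, a thunk is a zero-argument Python closure, "variables" are entries in a dict, and the names make_thunks, state, and log are hypothetical. The update thunks assert that the variables they depend on already exist, so running the lists in the order they are returned satisfies the partial order described above.

```python
def make_thunks(factor_names, state, log):
    """Build four parallel lists of thunks, one entry per factor."""
    cov_variable_thunks, cov_update_thunks = [], []
    inv_variable_thunks, inv_update_thunks = [], []

    for name in factor_names:
        # Default arguments bind the current factor name to each closure.
        def make_cov_var(name=name):
            state[("cov", name)] = 0.0
            log.append(("cov_var", name))

        def make_cov_update(name=name):
            # Needs the cov variable of the same factor.
            assert ("cov", name) in state
            log.append(("cov_update", name))

        def make_inv_var(name=name):
            state[("inv", name)] = 0.0
            log.append(("inv_var", name))

        def make_inv_update(name=name):
            # Needs both the cov and the inv variable of the same factor.
            assert ("cov", name) in state and ("inv", name) in state
            log.append(("inv_update", name))

        cov_variable_thunks.append(make_cov_var)
        cov_update_thunks.append(make_cov_update)
        inv_variable_thunks.append(make_inv_var)
        inv_update_thunks.append(make_inv_update)

    return (cov_variable_thunks, cov_update_thunks,
            inv_variable_thunks, inv_update_thunks)


state, log = {}, []
thunk_lists = make_thunks(["factor_a", "factor_b"], state, log)

# Execute the thunks in the order the lists are returned.
for thunks in thunk_lists:
    for thunk in thunks:
        thunk()

for entry in log:
    print(entry)
```

Calling an inv update thunk before the matching variable thunks would trip the assertion in this sketch, which mirrors the dependency the real thunks have on their variables.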
get_slot( var, name )
Return a slot named name created for var by the Optimizer.
Optimizer subclasses use additional variables. For example, Adagrad uses variables to accumulate updates. This method gives access to these Variable objects if for some reason you need them.
Use get_slot_names() to get the list of slot names created by the Optimizer.
var: A variable passed to
name: A string.
Variable for the slot if it was created,
Return a list of the names of slots created by the Optimizer.
A list of strings.
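To make the idea of slots concrete, here is a toy sketch of an optimizer that keeps one accumulator per variable. ToyMomentumOptimizer is a hypothetical stand-in, not a TensorFlow class: variables are identified by plain strings, slot values are floats, and returning None for a slot that was never created is an assumption of this sketch.

```python
class ToyMomentumOptimizer:
    """Toy stand-in (not a TensorFlow class) that keeps one slot per variable."""

    def __init__(self, momentum=0.5):
        self._momentum = momentum
        # Maps slot name -> {variable name -> accumulator value}.
        self._slots = {}

    def apply_gradient(self, var_name, grad):
        # Create the "momentum" slot lazily, the first time a variable is seen.
        accumulators = self._slots.setdefault("momentum", {})
        velocity = self._momentum * accumulators.get(var_name, 0.0) + grad
        accumulators[var_name] = velocity
        return velocity

    def get_slot(self, var_name, name):
        # Assumption of this sketch: a missing slot is reported as None.
        return self._slots.get(name, {}).get(var_name)

    def get_slot_names(self):
        return sorted(self._slots)


opt = ToyMomentumOptimizer(momentum=0.5)
opt.apply_gradient("w", 1.0)
opt.apply_gradient("w", 1.0)

print(opt.get_slot_names())
print(opt.get_slot("w", "momentum"))
print(opt.get_slot("b", "momentum"))
```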
Make vars and create op thunks.
cov_update_thunks: List of cov update thunks. Corresponds one-to-one with the list of factors given by the "factors" property.
inv_update_thunks: List of inv update thunks. Corresponds one-to-one with the list of factors given by the "factors" property.
minimize( *args, **kwargs )
set_damping_adaptation_params( is_chief, prev_train_batch, loss_fn, min_damping=1e-05, damping_adaptation_decay=0.99, damping_adaptation_interval=5 )
Sets parameters required to adapt damping during training.
When called, enables damping adaptation according to the Levenberg-Marquardt style rule described in Section 6.5 of "Optimizing Neural Networks with Kronecker-factored Approximate Curvature".
Note that this function creates TensorFlow variables which store a few scalars and are accessed by the ops which update the damping (as part of the training op returned by the minimize() method).
is_chief: True if the worker is chief.
prev_train_batch: Training data used to minimize loss in the previous step. This will be used to evaluate loss by calling
loss_fn: function that takes as input a training data tensor and returns a scalar loss.
min_damping: float (Optional), Minimum value the damping parameter can take. Default value 1e-5.
damping_adaptation_decay: float (Optional), The damping parameter is multiplied by damping_adaptation_decay every damping_adaptation_interval number of iterations. Default value 0.99.
damping_adaptation_interval: int (Optional), Number of steps in between updating the damping parameter. Default value 5.
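A Levenberg-Marquardt style damping rule can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the optimizer's implementation: the thresholds 0.25 and 0.75 and the use of damping_adaptation_decay ** damping_adaptation_interval as the multiplicative factor follow one reading of Section 6.5 of the paper, and the function names reduction_ratio and adapt_damping are hypothetical.

```python
def reduction_ratio(loss_before, loss_after, predicted_change):
    """Ratio of the actual loss change to the change predicted by the quadratic model."""
    return (loss_after - loss_before) / predicted_change


def adapt_damping(damping, rho, damping_adaptation_decay=0.99,
                  damping_adaptation_interval=5, min_damping=1e-5):
    """Shrink damping when the quadratic model is trustworthy, grow it otherwise."""
    # Assumed factor: one decay step per iteration since the last adaptation.
    omega = damping_adaptation_decay ** damping_adaptation_interval
    if rho > 0.75:
        damping *= omega      # model predicted well: trust it more
    elif rho < 0.25:
        damping /= omega      # model predicted poorly: trust it less
    return max(damping, min_damping)


damping = 1.0
# The loss dropped by about 0.2 where the model predicted a drop of 0.25.
rho = reduction_ratio(loss_before=1.0, loss_after=0.8, predicted_change=-0.25)
damping = adapt_damping(damping, rho)
print("rho=%.3f  new damping=%.5f" % (rho, damping))
```

In a training loop this update would run only once every damping_adaptation_interval steps, using loss_fn evaluated on prev_train_batch to obtain the actual loss change.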
set_damping_adaptation_params is already called and the