tf.contrib.kfac.estimator.FisherEstimator

Class FisherEstimator

Defined in tensorflow/contrib/kfac/python/ops/estimator.py.

Fisher estimator class supporting various approximations of the Fisher information matrix.

Attributes:

  • cov_update_thunks: list of no-arg functions. Executing a function adds covariance update ops for a single FisherFactor to the graph.
  • cov_update_ops: List of Ops. Running an op updates covariance matrices for a single FisherFactor.
  • cov_update_op: Op. Running it updates covariance matrices for all FisherFactors.
  • inv_update_thunks: list of no-arg functions. Executing a function adds inverse update ops for a single FisherFactor to the graph.
  • inv_update_ops: List of Ops. Running an op updates inverse matrices for a single FisherFactor.
  • inv_update_op: Op. Running it updates inverse matrices for all FisherFactors.
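
The updates are meant to be run in two phases: the covariance moving averages are refreshed frequently, while the (more expensive) inverses are recomputed at a lower frequency. A minimal numpy sketch of that pattern for a single factor — the `ToyFactor` class, its fields, and the decay/damping values are invented stand-ins for illustration, not the real FisherFactor API:

```python
import numpy as np

class ToyFactor:
    """Illustrative stand-in for a single FisherFactor: keeps an
    exponential moving average of a covariance and a cached damped inverse."""

    def __init__(self, dim, ema_decay=0.95, damping=1e-3):
        self.cov = np.eye(dim)      # running covariance estimate
        self.inv = np.eye(dim)      # cached (damped) inverse
        self.ema_decay = ema_decay
        self.damping = damping

    def cov_update(self, acts):
        """Analogue of running a cov_update_op: fold a fresh sample
        covariance into the exponential moving average."""
        sample_cov = acts.T @ acts / acts.shape[0]
        self.cov = self.ema_decay * self.cov + (1 - self.ema_decay) * sample_cov

    def inv_update(self):
        """Analogue of running an inv_update_op: recompute the damped inverse."""
        dim = self.cov.shape[0]
        self.inv = np.linalg.inv(self.cov + self.damping * np.eye(dim))

rng = np.random.default_rng(0)
factor = ToyFactor(dim=4)
for step in range(100):
    factor.cov_update(rng.standard_normal((32, 4)))
    if step % 10 == 0:              # inverses are refreshed less often
        factor.inv_update()
```

In the real class the same staleness trade-off applies: between calls to the inverse update op, `inv` lags behind the current covariance estimate.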

Properties

damping

variables

Methods

__init__

__init__(
    damping_fn,
    variables,
    cov_ema_decay,
    layer_collection,
    estimation_mode='gradients',
    colocate_gradients_with_ops=True,
    cov_devices=None,
    inv_devices=None
)

Create a FisherEstimator object.

Args:

  • damping_fn: Function that accepts no arguments and returns the damping value.
  • variables: A list of the variables for which to estimate the Fisher. This must match the variables registered in layer_collection (if it is not None).
  • cov_ema_decay: The decay factor used when calculating the covariance estimate moving averages.
  • layer_collection: The layer collection object, which holds the Fisher blocks, Kronecker factors, and losses associated with the graph.
  • estimation_mode: The type of estimator to use for the Fishers. Can be 'gradients', 'empirical', 'curvature_prop', or 'exact'. (Default: 'gradients'). 'gradients' is the basic estimation approach from the original K-FAC paper. 'empirical' computes the 'empirical' Fisher information matrix (which uses the data's distribution for the targets, as opposed to the true Fisher, which uses the model's distribution) and requires that each registered loss have specified targets. 'curvature_prop' estimates the Fisher using self-products of random +1/-1 vectors times "half-factors" of the Fisher, as described here: https://arxiv.org/abs/1206.6464 . Finally, 'exact' is the obvious generalization of curvature propagation: it computes the exact Fisher (modulo any additional diagonal or Kronecker approximations) by looping over one-hot vectors for each coordinate of the output instead of using random +1/-1 vectors. It is roughly more expensive to compute than the other three options by a factor equal to the output dimension.
  • colocate_gradients_with_ops: Whether we should request gradients be colocated with their respective ops. (Default: True)
  • cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance computations will be placed on these devices in a round-robin fashion. Can be None, which means that no devices are specified.
  • inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion computations will be placed on these devices in a round-robin fashion. Can be None, which means that no devices are specified.
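
The contrast between the 'curvature_prop' and 'exact' modes can be seen on a toy Fisher F = B Bᵀ built from a known "half-factor" B (the matrix B here is an arbitrary invented example, not something the estimator exposes): curvature propagation averages self-products of B s for random +1/-1 vectors s, since E[s sᵀ] = I, while the exact mode loops over one-hot vectors and recovers F with no sampling noise, at a cost that scales with the output dimension:

```python
import numpy as np

rng = np.random.default_rng(42)
out_dim, param_dim = 3, 5
B = rng.standard_normal((param_dim, out_dim))   # half-factor: F = B @ B.T
F = B @ B.T

# 'curvature_prop'-style estimate: average (B s)(B s)^T over random
# +1/-1 vectors s; unbiased because E[s s^T] = I.
n_samples = 20000
est = np.zeros((param_dim, param_dim))
for _ in range(n_samples):
    s = rng.choice([-1.0, 1.0], size=out_dim)
    v = B @ s
    est += np.outer(v, v)
est /= n_samples

# 'exact'-style computation: sum (B e_i)(B e_i)^T over one-hot vectors
# e_i, which equals B B^T exactly but costs out_dim passes.
exact = np.zeros((param_dim, param_dim))
for i in range(out_dim):
    e = np.zeros(out_dim)
    e[i] = 1.0
    v = B @ e
    exact += np.outer(v, v)
```

The sampled estimate converges to F at the usual O(1/sqrt(n)) Monte Carlo rate, while the one-hot loop reproduces it exactly.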

Raises:

  • ValueError: If no losses have been registered with layer_collection.
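
The round-robin placement described for cov_devices and inv_devices can be sketched as below. The factor names are made up for illustration; the real class does the equivalent internally when building its update ops (via device contexts), rather than exposing an assignment mapping:

```python
import itertools

cov_devices = ['/gpu:0', '/gpu:1']
factors = ['layer0/cov', 'layer1/cov', 'layer2/cov', 'layer3/cov', 'layer4/cov']

# Cycle through the device list, assigning one device per factor in order;
# zip stops when the finite factor list is exhausted.
assignment = dict(zip(factors, itertools.cycle(cov_devices)))
```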

multiply

multiply(vecs_and_vars)

Multiplies the vectors by the corresponding (damped) blocks.

Args:

  • vecs_and_vars: List of (vector, variable) pairs.

Returns:

A list of (transformed vector, var) pairs in the same order as vecs_and_vars.
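
Under a block-diagonal approximation, this amounts to applying (F_b + λI) to each variable's vector, where F_b is that variable's Fisher block and λ is the damping. An illustrative numpy version — the block matrices, variable names, and damping value are invented for the example:

```python
import numpy as np

damping = 1e-2
rng = np.random.default_rng(1)

# One invented, symmetric PSD "Fisher block" per variable, plus the
# (vector, variable) pairs to transform.
blocks = {}
vecs_and_vars = []
for var in ['w1', 'w2']:
    A = rng.standard_normal((4, 4))
    blocks[var] = A @ A.T
    vecs_and_vars.append((rng.standard_normal(4), var))

def multiply(vecs_and_vars):
    """Apply the damped block (F_b + damping * I) to each vector,
    preserving the input order."""
    out = []
    for vec, var in vecs_and_vars:
        F_b = blocks[var]
        out.append((F_b @ vec + damping * vec, var))
    return out

result = multiply(vecs_and_vars)
```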

multiply_inverse

multiply_inverse(vecs_and_vars)

Multiplies the vectors by the corresponding (damped) inverses of the blocks.

Args:

  • vecs_and_vars: List of (vector, variable) pairs.

Returns:

A list of (transformed vector, var) pairs in the same order as vecs_and_vars.
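
Analogously, this applies (F_b + λI)⁻¹ per block, so composing it with the damped multiply should act as the identity. A numpy sketch under the same kind of invented block setup as above:

```python
import numpy as np

damping = 1e-2
rng = np.random.default_rng(2)
A = rng.standard_normal((4, 4))
blocks = {'w': A @ A.T}              # invented symmetric PSD block
vecs_and_vars = [(rng.standard_normal(4), 'w')]

def multiply_inverse(vecs_and_vars):
    """Apply the inverse of the damped block to each vector by solving
    (F_b + damping * I) x = vec."""
    out = []
    for vec, var in vecs_and_vars:
        damped = blocks[var] + damping * np.eye(len(vec))
        out.append((np.linalg.solve(damped, vec), var))
    return out

def multiply(vecs_and_vars):
    """Apply the damped block itself, as in the multiply method."""
    return [(blocks[var] @ v + damping * v, var) for v, var in vecs_and_vars]

# Round trip: multiply(multiply_inverse(v)) recovers v.
recovered = multiply(multiply_inverse(vecs_and_vars))
```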