tf.contrib.kfac.fisher_blocks.FullyConnectedDiagonalFB

Class FullyConnectedDiagonalFB

Defined in tensorflow/contrib/kfac/python/ops/fisher_blocks.py.

FisherBlock for fully-connected (dense) layers using a diagonal approx.

Estimates the Fisher Information matrix's diagonal entries for a fully connected layer. Unlike NaiveDiagonalFB, this uses the low-variance "sum of squares" estimator.

Let 'params' be a vector parameterizing a model and 'i' an arbitrary index into it. We are interested in Fisher(params)[i, i]. This is given by

$$\mathrm{Fisher}(\mathrm{params})[i, i] = E\left[\, v(x, y, \mathrm{params})\, v(x, y, \mathrm{params})^{T} \,\right][i, i] = E\left[\, v(x, y, \mathrm{params})[i]^{2} \,\right]$$

Consider a fully connected layer in this model with (unshared) weight matrix 'w'. For an example 'x' that produces layer inputs 'a' and output preactivations 's',

$$v(x, y, w) = \mathrm{vec}\!\left( a \left( \frac{\partial\, \mathrm{loss}}{\partial s} \right)^{T} \right)$$

This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding to the layer's parameters 'w'.
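
As a concrete illustration of this estimator, the following NumPy sketch averages the squared per-example gradients vec( a (d loss / d s)^T ) over a batch. It is illustrative only, not the library's internal implementation; all names and shapes are hypothetical.

import numpy as np

# Hypothetical shapes: a batch of 4 examples, a layer with 3 inputs and 2 outputs.
batch_size, in_dim, out_dim = 4, 3, 2
a = np.random.randn(batch_size, in_dim)    # layer inputs 'a', one row per example
ds = np.random.randn(batch_size, out_dim)  # per-example gradients d loss / d s

# Per-example parameter gradient v = vec( a (d loss / d s)^T ), flattened row-major.
v = np.einsum('bi,bo->bio', a, ds).reshape(batch_size, in_dim * out_dim)

# "Sum of squares" estimator of the Fisher diagonal: E[ v[i]^2 ], approximated by
# averaging the squared per-example gradients over the batch.
fisher_diag_estimate = np.mean(v ** 2, axis=0)  # shape [in_dim * out_dim]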

Properties

num_registered_towers

Methods

__init__

__init__(
    layer_collection,
    has_bias=False
)

Creates a FullyConnectedDiagonalFB block.

Args:

  • layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs.
  • has_bias: Whether the component Kronecker factors have an additive bias. (Default: False)
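
In normal use this block is created indirectly, by registering a fully connected layer with a LayerCollection; the sketch below constructs it directly, purely to illustrate the signature above. The tensor names and shapes are hypothetical, and register_additional_tower is assumed to take the layer inputs and preactivations as shown in its listing further down this page.

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import fisher_blocks
from tensorflow.contrib.kfac.python.ops import layer_collection

lc = layer_collection.LayerCollection()

# A hypothetical fully connected layer: inputs 'a' and output preactivations 's'.
a = tf.random_normal([32, 128])            # layer inputs
w = tf.get_variable('w', shape=[128, 64])
s = tf.matmul(a, w)                        # preactivations

block = fisher_blocks.FullyConnectedDiagonalFB(lc, has_bias=False)
block.register_additional_tower(inputs=a, outputs=s)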

full_fisher_block

full_fisher_block()

instantiate_factors

instantiate_factors(
    grads_list,
    damping
)

Creates and registers the component factors of this Fisher block.

Args:

  • grads_list: A list of gradients (each a Tensor or tuple of Tensors) with respect to the tensors returned by tensors_to_compute_grads() that are to be used to estimate the block.
  • damping: The damping factor (float or Tensor).
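
Continuing the hypothetical graph above, the gradients in grads_list are taken with respect to the tensors the block reports via tensors_to_compute_grads() (here, the preactivations 's'); these are the d loss / d s terms in the estimator. In normal use the surrounding estimator assembles grads_list and calls this method itself, so the sketch stops at computing the gradients.

# A stand-in loss defined on the hypothetical preactivations 's' above.
loss = tf.reduce_sum(tf.square(s))

# These gradients are the 'd loss / d s' terms; the estimator machinery groups
# them (per source and per tower) into grads_list before calling
# block.instantiate_factors(grads_list, damping) with a small damping constant.
grads = tf.gradients(loss, block.tensors_to_compute_grads())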

multiply

multiply(vector)

Multiplies the vector by the (damped) block.

Args:

  • vector: The vector (a Tensor or tuple of Tensors) to be multiplied.

Returns:

The vector left-multiplied by the (damped) block.
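
Since this block approximates the Fisher by its diagonal, multiplying by the (damped) block amounts, conceptually, to an elementwise product with the estimated diagonal plus the damping constant. A hedged NumPy illustration follows; the actual implementation operates on Tensors shaped like the layer's parameters.

import numpy as np

fisher_diag = np.array([0.5, 2.0, 1.0])   # hypothetical estimated diagonal entries
damping = 0.01
vector = np.array([1.0, -1.0, 3.0])

# (diag(fisher_diag) + damping * I) @ vector, written elementwise.
result = (fisher_diag + damping) * vector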

multiply_cholesky

multiply_cholesky(
    vector,
    transpose=False
)

Multiplies the vector by the (damped) Cholesky-factor of the block.

Args:

  • vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
  • transpose: Bool. If True, the Cholesky factor is transposed before multiplying the vector. (Default: False)

Returns:

The vector left-multiplied by the (damped) Cholesky-factor of the block.

multiply_cholesky_inverse

multiply_cholesky_inverse(
    vector,
    transpose=False
)

Multiplies the vector by the (damped) inverse Cholesky-factor of the block.

Args:

  • vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
  • transpose: Bool. If True, the inverse Cholesky factor is transposed before multiplying the vector. (Default: False)

Returns:

The vector left-multiplied by the (damped) inverse Cholesky-factor of the block.

multiply_inverse

multiply_inverse(vector)

Multiplies the vector by the (damped) inverse of the block.

Args:

  • vector: The vector (a Tensor or tuple of Tensors) to be multiplied.

Returns:

The vector left-multiplied by the (damped) inverse of the block.

multiply_matpower

multiply_matpower(
    vector,
    exp
)

Multiplies the vector by the (damped) matrix-power of the block.

Args:

  • vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
  • exp: A float representing the power to raise the block by before multiplying it by the vector.

Returns:

The vector left-multiplied by the (damped) matrix-power of the block.
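
For a diagonal block, a matrix power is (conceptually) an elementwise power of the damped diagonal, which subsumes multiply (exp = 1), multiply_inverse (exp = -1), and a square-root factor (exp = 0.5). Another hedged NumPy illustration, reusing the hypothetical values above:

def multiply_matpower_diag(vec, exp):
    # (diag(fisher_diag) + damping * I) ** exp applied elementwise, then multiplied.
    return (fisher_diag + damping) ** exp * vec

same_as_multiply = multiply_matpower_diag(vector, 1.0)
same_as_multiply_inverse = multiply_matpower_diag(vector, -1.0)
square_root_factor_product = multiply_matpower_diag(vector, 0.5)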

register_additional_tower

register_additional_tower(
    inputs,
    outputs
)

register_cholesky

register_cholesky()

Registers a Cholesky factor to be computed by the block.

register_cholesky_inverse

register_cholesky_inverse()

Registers an inverse Cholesky factor to be computed by the block.

register_inverse

register_inverse()

Registers a matrix inverse to be computed by the block.

register_matpower

register_matpower(exp)

Registers a matrix power to be computed by the block.

Args:

  • exp: A float representing the power to raise the block by.

tensors_to_compute_grads

tensors_to_compute_grads()

Tensors to compute the derivative of the loss with respect to.