tf.contrib.kfac.fisher_factors.ConvInputKroneckerFactor

Class ConvInputKroneckerFactor

Defined in tensorflow/contrib/kfac/python/ops/fisher_factors.py.

Kronecker factor for the input side of a convolutional layer.

Estimates E[a a^T], where a is the input to a convolutional layer given example x. The expectation is taken over all examples and spatial locations.

Equivalent to Omega in https://arxiv.org/abs/1602.01407. See Section 3.1, "Estimating the factors", for details.
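The estimate can be illustrated with a minimal NumPy sketch. It assumes 2-D NHWC inputs, VALID padding and a single tower, and it is not the library's implementation. The helper names extract_patches_valid and conv_input_factor are hypothetical. Each flattened patch is one vector a, and the outer products are averaged over every (example, location) pair.

```python
import numpy as np


def extract_patches_valid(x, kh, kw, stride=1):
    """Flatten every kh x kw patch of an NHWC array, using VALID padding.

    Returns an array of shape [batch, out_h, out_w, kh * kw * in_channels].
    """
    batch, h, w, c = x.shape
    out_h = (h - kh) // stride + 1
    out_w = (w - kw) // stride + 1
    patches = np.empty((batch, out_h, out_w, kh * kw * c), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, i * stride:i * stride + kh, j * stride:j * stride + kw, :]
            patches[:, i, j, :] = window.reshape(batch, -1)
    return patches


def conv_input_factor(x, kh, kw, stride=1):
    """Estimate E[a a^T], averaging over all examples and spatial locations."""
    patches = extract_patches_valid(x, kh, kw, stride)
    a = patches.reshape(-1, patches.shape[-1])  # one row per (example, location)
    return a.T @ a / a.shape[0]


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 5, 5, 3))  # batch=4, 5x5 spatial, 3 channels
omega = conv_input_factor(x, kh=3, kw=3)
print(omega.shape)  # (27, 27)
```

The result is symmetric and positive semi-definite, with one row and column per entry of a flattened 3x3x3 patch.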

Properties

name

Methods

__init__

__init__(
    inputs,
    filter_shape,
    padding,
    strides=None,
    dilation_rate=None,
    data_format=None,
    extract_patches_fn=None,
    has_bias=False,
    sub_sample_inputs=None,
    sub_sample_patches=None
)

Initializes ConvInputKroneckerFactor.

Args:

  • inputs: List of Tensors of shape [batch_size, ..spatial_input_size.., in_channels]. Inputs to the layer. The list index is the tower index.
  • filter_shape: List of ints. Contains [..spatial_filter_size.., in_channels, out_channels]. Shape of convolution kernel.
  • padding: str. Padding method for layer. "SAME" or "VALID".
  • strides: List of ints or None. Contains [..spatial_filter_strides..] if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_filter_strides.., 1].
  • dilation_rate: List of ints or None. Rate for dilation along each spatial dimension if 'extract_patches_fn' is compatible with tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
  • data_format: str or None. Format of input data.
  • extract_patches_fn: str or None. Name of function that extracts image patches. One of "extract_convolution_patches", "extract_image_patches", "extract_pointwise_conv2d_patches".
  • has_bias: bool. If True, the layer has a bias, and a 1 is appended to in_channel to account for it.
  • sub_sample_inputs: bool. If True, then subsample the inputs from which the image patches are extracted. (Default: None)
  • sub_sample_patches: bool. If True, then subsample the extracted patches. (Default: None)
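The effect of filter_shape and has_bias on the factor's size can be sketched as follows. This assumes the appended 1 is added to every flattened patch (a homogeneous coordinate), so the factor grows by one row and one column. That is an assumption of this sketch, and factor_dim and append_bias_column are hypothetical helpers, not part of the library.

```python
import numpy as np


def factor_dim(filter_shape, has_bias):
    """Side length of the input factor for a kernel shaped
    [..spatial_filter_size.., in_channels, out_channels]."""
    spatial_size = int(np.prod(filter_shape[:-2]))
    in_channels = filter_shape[-2]
    return spatial_size * in_channels + (1 if has_bias else 0)


def append_bias_column(flat_patches, has_bias):
    """Append a constant 1 to every flattened patch (row) when has_bias is True."""
    if not has_bias:
        return flat_patches
    ones = np.ones((flat_patches.shape[0], 1), dtype=flat_patches.dtype)
    return np.concatenate([flat_patches, ones], axis=1)


print(factor_dim([3, 3, 16, 32], has_bias=False))  # 144
print(factor_dim([3, 3, 16, 32], has_bias=True))   # 145
```

Note that out_channels does not affect the input-side factor; it only sizes the output-side factor.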

get_cholesky

get_cholesky(damping_func)

get_cholesky_inverse

get_cholesky_inverse(damping_func)

get_cov

get_cov()

get_cov_as_linear_operator

get_cov_as_linear_operator()

get_eigendecomp

get_eigendecomp()

Creates or retrieves the eigendecomposition of self._cov.
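For a symmetric covariance, the eigendecomposition can be sketched with numpy.linalg.eigh. The clipping of small negative eigenvalues below is a choice of this sketch, assumed rather than taken from the library, and eigendecomp is a hypothetical helper.

```python
import numpy as np


def eigendecomp(cov):
    """Return (eigenvalues, eigenvectors) with cov = V @ diag(e) @ V.T."""
    evals, evecs = np.linalg.eigh(cov)  # eigh: symmetric input, ascending eigenvalues
    # Clip small negative eigenvalues caused by round-off (a choice of this sketch).
    evals = np.maximum(evals, 0.0)
    return evals, evecs


rng = np.random.default_rng(1)
a = rng.standard_normal((50, 4))
cov = a.T @ a / a.shape[0]
evals, evecs = eigendecomp(cov)
```

Caching one eigendecomposition lets several damped matrix powers be formed from it without decomposing the covariance again.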

get_inverse

get_inverse(damping_func)

get_matpower

get_matpower(
    exp,
    damping_func
)

instantiate_cov_variables

instantiate_cov_variables()

Makes the internal cov variable(s).

instantiate_inv_variables

instantiate_inv_variables()

Makes the internal "inverse" variable(s).

make_covariance_update_op

make_covariance_update_op(ema_decay)

Constructs and returns the covariance update Op.

Args:

  • ema_decay: The exponential moving average decay (float or Tensor).

Returns:

An Op for updating the covariance Variable referenced by _cov.
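The update is an exponential moving average. A minimal NumPy sketch of one step, assuming the standard form decay * old + (1 - decay) * new with no bias correction, looks like this; ema_update is a hypothetical name.

```python
import numpy as np


def ema_update(cov, new_cov, ema_decay):
    """One exponential-moving-average step: decay * old + (1 - decay) * new."""
    return ema_decay * cov + (1.0 - ema_decay) * new_cov


cov = np.zeros((2, 2))
batch_cov = np.eye(2)  # stands in for a per-batch estimate of E[a a^T]
cov = ema_update(cov, batch_cov, ema_decay=0.95)
print(cov)  # 0.05 on the diagonal, 0 elsewhere
```

A decay close to 1 averages over many batches, which smooths the estimate but makes it slower to track changes in the inputs.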

make_inverse_update_ops

make_inverse_update_ops()

Create and return update ops corresponding to registered computations.
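The register / update / get pattern behind the methods below can be sketched with a hypothetical pure-NumPy class, ToyFactor. It is not the library's class: plain callables stand in for TensorFlow update ops, and the damped inverse is assumed to be (cov + damping * I)^-1.

```python
import numpy as np


class ToyFactor:
    """Hypothetical stand-in mimicking the register/update/get workflow."""

    def __init__(self, cov):
        self.cov = cov
        self._inverse_damping_funcs = []  # registered computations
        self._inverses = {}               # damping_func -> stored result

    def register_inverse(self, damping_func):
        # Only record the request; nothing is computed yet.
        if damping_func not in self._inverse_damping_funcs:
            self._inverse_damping_funcs.append(damping_func)

    def make_inverse_update_ops(self):
        # One callable per registered computation, standing in for an update op.
        ops = []
        for damping_func in self._inverse_damping_funcs:
            def op(damping_func=damping_func):
                n = self.cov.shape[0]
                damped = self.cov + damping_func() * np.eye(n)
                self._inverses[damping_func] = np.linalg.inv(damped)
            ops.append(op)
        return ops

    def get_inverse(self, damping_func):
        return self._inverses[damping_func]


factor = ToyFactor(np.eye(2))
damping = lambda: 1.0
factor.register_inverse(damping)
factor.register_inverse(damping)  # registering twice has no extra effect
ops = factor.make_inverse_update_ops()
for op in ops:
    op()
print(factor.get_inverse(damping))  # 0.5 * identity
```

Separating registration from update lets the expensive inversions run on their own schedule while readers only fetch the stored result.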

register_cholesky

register_cholesky(damping_func)

Registers a Cholesky factor to be maintained and served on demand.

This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method get_cholesky.

Args:

  • damping_func: A function that computes a 0-D Tensor or a float which will be the damping value used, i.e. damping = damping_func().
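A damped Cholesky factor can be sketched in NumPy as follows, assuming damping is added to the diagonal as cov + damping * I before factoring. damped_cholesky is a hypothetical helper, not the library's method.

```python
import numpy as np


def damped_cholesky(cov, damping_func):
    """Lower-triangular L with L @ L.T == cov + damping * I."""
    damping = damping_func()
    n = cov.shape[0]
    return np.linalg.cholesky(cov + damping * np.eye(n))


cov = np.array([[2.0, 1.0], [1.0, 2.0]])
L = damped_cholesky(cov, lambda: 0.1)
```

Positive damping keeps the matrix strictly positive definite, so the factorization succeeds even when the raw covariance is singular.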

register_cholesky_inverse

register_cholesky_inverse(damping_func)

Registers an inverse Cholesky factor to be maintained and served on demand.

This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method get_cholesky_inverse.

Args:

  • damping_func: A function that computes a 0-D Tensor or a float which will be the damping value used, i.e. damping = damping_func().
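The inverse Cholesky factor can be sketched by solving L X = I, again assuming damping enters as cov + damping * I. damped_cholesky_inverse is a hypothetical helper.

```python
import numpy as np


def damped_cholesky_inverse(cov, damping_func):
    """Inverse of the lower-triangular Cholesky factor of cov + damping * I."""
    n = cov.shape[0]
    chol = np.linalg.cholesky(cov + damping_func() * np.eye(n))
    return np.linalg.solve(chol, np.eye(n))  # solves chol @ X = I


cov = np.array([[2.0, 1.0], [1.0, 2.0]])
L_inv = damped_cholesky_inverse(cov, lambda: 0.1)
```

Because (L L^T)^-1 = L^-T L^-1, the product L_inv.T @ L_inv recovers the damped inverse.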

register_inverse

register_inverse(damping_func)

register_matpower

register_matpower(
    exp,
    damping_func
)

Registers a matrix power to be maintained and served on demand.

This creates a variable and signals make_inverse_update_ops to make the corresponding update op. The variable can be read via the method get_matpower.

Args:

  • exp: float. The exponent to use in the matrix power.
  • damping_func: A function that computes a 0-D Tensor or a float which will be the damping value used, i.e. damping = damping_func().
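A damped matrix power can be sketched through the eigendecomposition, assuming the quantity is (cov + damping * I)^exp, so the damping shifts every eigenvalue before the power is applied. damped_matpower is a hypothetical helper.

```python
import numpy as np


def damped_matpower(cov, exp, damping_func):
    """(cov + damping * I) ** exp for a symmetric PSD cov, via eigendecomposition."""
    evals, evecs = np.linalg.eigh(cov)
    shifted = evals + damping_func()  # positive when damping > 0
    # evecs * s scales column j by s[j], i.e. V @ diag(s); then multiply by V.T.
    return (evecs * shifted ** exp) @ evecs.T


cov = np.array([[2.0, 1.0], [1.0, 2.0]])
inv = damped_matpower(cov, -1.0, lambda: 0.1)
sqrt = damped_matpower(cov, 0.5, lambda: 0.1)
```

With exp = -1 this reduces to the damped inverse, and with exp = 0.5 it gives a symmetric square root.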