Class ConvDiagonalFactor
Inherits From: DiagonalFactor
Defined in tensorflow/contrib/kfac/python/ops/fisher_factors.py
FisherFactor for a diagonal approximation of a convolutional layer's Fisher information matrix.
Properties
name
Methods
__init__
__init__(
    inputs,
    outputs_grads,
    filter_shape,
    strides,
    padding,
    data_format=None,
    dilations=None,
    has_bias=False
)
Creates a ConvDiagonalFactor object.
Args:
inputs
: List of Tensors of shape [batch_size, height, width, in_channels]. Input activations to this layer. The list index is the tower.
outputs_grads
: List of Tensors, each of shape [batch_size, height, width, out_channels], which are the gradients of the loss with respect to the layer's outputs. First index is source, second index is tower.
filter_shape
: Tuple of 4 ints: (kernel_height, kernel_width, in_channels, out_channels). Represents the shape of the kernel used in this layer.
strides
: The stride size in this layer (1-D Tensor of length 4).
padding
: The padding in this layer (1-D Tensor of length 4).
data_format
: None or str. Format of conv2d inputs.
dilations
: None or tuple of 4 ints.
has_bias
: Python bool. If True, the layer is assumed to have a bias parameter in addition to its filter parameter.
Raises:
ValueError
: If inputs, outputs_grads, and filter_shape do not agree on in_channels or out_channels.
ValueError
: If strides or dilations are not length-4 lists of ints.
ValueError
: If data_format does not put the channel dimension last.
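The three conditions listed under Raises can be sketched in plain Python, working on shape tuples instead of Tensors. The helper name `check_conv_factor_args`, its parameter names, and its error messages are hypothetical and are not part of fisher_factors.py. Treating `data_format=None` as "NHWC" is also an assumption made for this illustration.

```python
def check_conv_factor_args(input_shape, outputs_grad_shape, filter_shape,
                           strides, dilations=None, data_format=None):
    """Hypothetical helper mirroring the documented ValueError conditions."""
    # data_format must put channels last; None is treated as "NHWC" (assumption).
    if data_format not in (None, "NHWC"):
        raise ValueError("data_format must put channels last, got %r"
                         % (data_format,))

    # strides and dilations, when given, must be length-4 sequences of ints.
    for name, value in (("strides", strides), ("dilations", dilations)):
        if value is None:
            continue
        if len(value) != 4 or not all(isinstance(v, int) for v in value):
            raise ValueError("%s must be a length-4 list of ints" % name)

    # filter_shape is (kernel_height, kernel_width, in_channels, out_channels).
    _, _, in_channels, out_channels = filter_shape
    if input_shape[-1] != in_channels:
        raise ValueError("inputs and filter_shape disagree on in_channels")
    if outputs_grad_shape[-1] != out_channels:
        raise ValueError("outputs_grads and filter_shape disagree on out_channels")


# Shapes for a single tower: batch of 8, 3x3 kernel, 16 -> 32 channels.
check_conv_factor_args(
    input_shape=(8, 28, 28, 16),
    outputs_grad_shape=(8, 28, 28, 32),
    filter_shape=(3, 3, 16, 32),
    strides=[1, 1, 1, 1],
)
```

Running the same kind of check on shapes before building the factor can surface a channel mismatch earlier than the constructor would.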
get_cholesky
get_cholesky(damping_func)
get_cholesky_inverse
get_cholesky_inverse(damping_func)
get_cov
get_cov()
get_cov_as_linear_operator
get_cov_as_linear_operator()
get_matpower
get_matpower(
    exp,
    damping_func
)
instantiate_cov_variables
instantiate_cov_variables()
Makes the internal cov variable(s).
instantiate_inv_variables
instantiate_inv_variables()
make_covariance_update_op
make_covariance_update_op(ema_decay)
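The `ema_decay` argument suggests that the covariance estimate is kept as an exponential moving average. The sketch below illustrates that idea, with plain Python lists standing in for the diagonal cov variable. The function `ema_update` is hypothetical. The exact update the library performs (for example, any zero-debiasing or averaging across towers) may differ.

```python
def ema_update(old_cov, new_cov, ema_decay):
    """Hypothetical elementwise exponential moving average of a diagonal estimate."""
    return [ema_decay * old + (1.0 - ema_decay) * new
            for old, new in zip(old_cov, new_cov)]


cov = [1.0, 1.0, 1.0]                 # current running estimate
minibatch_estimate = [0.0, 2.0, 4.0]  # estimate from the latest mini-batch
cov = ema_update(cov, minibatch_estimate, ema_decay=0.5)
print(cov)  # [0.5, 1.5, 2.5]
```

A value of `ema_decay` closer to 1 weights the running estimate more heavily, so the estimate changes more slowly from one mini-batch to the next.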
make_inverse_update_ops
make_inverse_update_ops()
register_cholesky
register_cholesky(damping_func)
register_cholesky_inverse
register_cholesky_inverse(damping_func)
register_matpower
register_matpower(
    exp,
    damping_func
)