Class for computing matrix-vector products for Fisher matrices, generalized Gauss-Newton (GGN) matrices, and Hessians.
In other words, we compute M*v, where M is the matrix, v is the vector, and * denotes standard matrix-vector multiplication (not element-wise multiplication).
The matrices are defined in terms of some differential quantity of the total loss function with respect to a provided list of tensors ("wrt_tensors"). For example, M could be the Fisher associated with a log-prob loss, taken w.r.t. the model parameters.
The 'vecs' argument to each method is a list of tensors, each of which must be the same size as the corresponding tensor in "wrt_tensors". Together they represent the vector being multiplied.
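Conceptually, M acts on the concatenation of all tensors in 'vecs', and the result is split back into pieces shaped like "wrt_tensors". The minimal pure-Python sketch below illustrates that bookkeeping under simplifying assumptions: each tensor is a flat 1-D list, M is an explicit dense matrix (which the class itself avoids forming), and the helper names `matvec`, `split` and `multiply_list` are hypothetical.

```python
# Hypothetical helpers: each "tensor" is a plain 1-D Python list here,
# and M is an explicit dense matrix (the real class never builds M).

def matvec(m, v):
    """Standard matrix-vector product m * v (not element-wise)."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def split(flat, sizes):
    """Cut a flat vector back into consecutive pieces of the given sizes."""
    pieces, start = [], 0
    for n in sizes:
        pieces.append(flat[start:start + n])
        start += n
    return pieces


def multiply_list(m, vecs):
    """Treat vecs as one concatenated vector, multiply by m, and split the
    result so each piece matches the corresponding entry of vecs."""
    flat = [x for t in vecs for x in t]
    sizes = [len(t) for t in vecs]
    return split(matvec(m, flat), sizes)


# Two wrt_tensors with 2 and 1 elements, so M is 3x3.
M = [[2.0, 0.0, 1.0],
     [0.0, 3.0, 0.0],
     [1.0, 0.0, 4.0]]
vecs = [[1.0, 2.0], [3.0]]
result = multiply_list(M, vecs)   # [[5.0, 6.0], [13.0]]
```

The output keeps one entry per input tensor, which is why each entry of 'vecs' must match the size of its counterpart in "wrt_tensors".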
A "factor" of the matrix M is defined as a matrix B such that B*B^T = M. Methods that multiply by the factor B take a 'loss_inner_vecs' argument instead of 'vecs'; it must be a list of tensors whose shapes are given by the corresponding *_inner_shapes property (for example, fisher_factor_inner_shapes for multiply_fisher_factor).
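Because B*B^T = M, multiplying by B^T and then by B reproduces M*v, and the intermediate vector B^T*v has the "inner" size (the number of columns of B) rather than the size of "wrt_tensors". The sketch below checks this on a small, hypothetical dense factor; it only illustrates the algebra and is not how the class represents B.

```python
def transpose(m):
    return [list(col) for col in zip(*m)]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def matmul(a, b):
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]


# A hypothetical 3x2 factor: 3 entries in wrt_tensors, 2 "inner" entries.
B = [[1.0, 0.0],
     [2.0, 1.0],
     [0.0, 3.0]]
M = matmul(B, transpose(B))       # 3x3 symmetric positive semi-definite

v = [1.0, -1.0, 2.0]
inner = matvec(transpose(B), v)   # B^T v has the "inner" shape (length 2)
via_factor = matvec(B, inner)     # B (B^T v)
direct = matvec(M, v)             # M v
```

Here `inner` is the kind of vector a 'loss_inner_vecs' argument would hold, and both routes give the same product.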
Note that the matrix-vector products are neither normalized by the batch size nor augmented with damping terms. Both can easily be applied externally if desired.
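Since both adjustments are linear, they can be applied to the raw product afterwards: (M / n + lambda*I)*v = (M*v) / n + lambda*v. The sketch below assumes the total loss sums over a batch of n examples (so dividing by n averages it) and uses simple Tikhonov damping lambda*I; the function name `normalize_and_damp` is hypothetical.

```python
def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def normalize_and_damp(mv, v, batch_size, damping):
    """Turn a raw product M*v into (M / batch_size + damping * I) * v.

    mv is assumed to be the unnormalized, undamped product of M with v.
    """
    return [x / batch_size + damping * y for x, y in zip(mv, v)]


M = [[4.0, 0.0], [0.0, 8.0]]
v = [1.0, 2.0]
raw = matvec(M, v)                       # [4.0, 16.0]
adjusted = normalize_and_damp(raw, v, batch_size=4, damping=0.5)

# Same thing with the adjusted matrix built explicitly, for comparison.
explicit = matvec([[4.0 / 4 + 0.5, 0.0], [0.0, 8.0 / 4 + 0.5]], v)
```

With a list of tensors, the same element-wise formula would be applied to each output tensor and its matching input tensor.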
See, for example, www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf and https://arxiv.org/abs/1412.1193 for more information about the generalized Gauss-Newton matrix, the Fisher, etc., and how to compute matrix-vector products with them.
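In those references the GGN takes the form G = J^T H J, where J is the Jacobian of the network outputs with respect to the parameters and H is the Hessian of the loss with respect to those outputs, so G*v can be computed as J^T(H(J v)) without ever forming G. The toy numbers below are hypothetical and the matrices are dense only for checking; in an autodiff framework J v would typically be a forward-mode product and J^T u a reverse-mode one.

```python
def transpose(m):
    return [list(col) for col in zip(*m)]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def matmul(a, b):
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]


# Toy setting (hypothetical numbers): 3 parameters, 2 network outputs.
J = [[1.0, 0.0, 2.0],     # Jacobian of outputs w.r.t. parameters (2x3)
     [0.0, 1.0, -1.0]]
H = [[2.0, 1.0],          # Hessian of the loss w.r.t. the outputs (2x2)
     [1.0, 3.0]]


def ggn_vector_product(v):
    """G v = J^T (H (J v)), computed without ever forming the 3x3 matrix G."""
    return matvec(transpose(J), matvec(H, matvec(J, v)))


v = [1.0, 1.0, 1.0]
matrix_free = ggn_vector_product(v)
G = matmul(transpose(J), matmul(H, J))   # explicit G, only for checking
explicit = matvec(G, v)
```

The matrix-free route only ever touches vectors of the output size and the parameter size, which is what makes such products affordable for large models.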
fisher_factor_inner_shapes
Shapes required by multiply_fisher_factor.
generalized_gauss_newton_factor_inner_shapes
Shapes required by multiply_generalized_gauss_newton_factor.
__init__( losses, wrt_tensors )
Create a CurvatureMatrixVectorProductComputer object.
losses: A list of LossFunction instances whose sum defines the total loss.
wrt_tensors: A list of Tensors with respect to which the differential quantities (and hence the matrices) are computed. See the class description for more info.
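Because the total loss is the sum of the given losses, and differentiation is linear, the Hessian of the total loss is the sum of the per-loss Hessians; a product with the total matrix can therefore be accumulated loss by loss. The sketch assumes the Fisher and GGN decompose the same way (true when each is defined as a sum of per-loss terms); the matrices and the `total_product` helper are hypothetical.

```python
def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def total_product(per_loss_matrices, v):
    """Sum the per-loss products M_i v, which equals (sum_i M_i) v."""
    total = [0.0] * len(v)
    for m in per_loss_matrices:
        total = [t + x for t, x in zip(total, matvec(m, v))]
    return total


M1 = [[1.0, 0.0], [0.0, 2.0]]   # hypothetical matrix for loss 1
M2 = [[0.0, 1.0], [1.0, 0.0]]   # hypothetical matrix for loss 2
v = [3.0, 4.0]
summed = total_product([M1, M2], v)
explicit = matvec([[1.0, 1.0], [1.0, 2.0]], v)   # (M1 + M2) v
```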
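multiply_fisher( vecs )
Multiply vecs by the Fisher of the total loss.
multiply_fisher_factor( loss_inner_vecs )
Multiply loss_inner_vecs by the factor of the Fisher of the total loss.
multiply_fisher_factor_transpose( vecs )
Multiply vecs by the transpose of the factor of the Fisher of the total loss.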
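As a concrete illustration of a Fisher factor, consider a softmax cross-entropy (log-prob) loss with predicted probabilities p: its Fisher with respect to the logits is F = diag(p) - p p^T, and one valid factor is B = diag(sqrt(p)) - p sqrt(p)^T, which satisfies B B^T = F because the entries of p sum to 1. Factors are not unique (B Q works for any orthogonal Q), so the class may use a different one; the probabilities and helper names below are hypothetical, and the comments only mark which documented method each step is analogous to.

```python
import math


def transpose(m):
    return [list(col) for col in zip(*m)]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def fisher_wrt_logits(p):
    """F = diag(p) - p p^T for a softmax cross-entropy (log-prob) loss."""
    n = len(p)
    return [[(p[i] if i == j else 0.0) - p[i] * p[j] for j in range(n)]
            for i in range(n)]


def fisher_factor_wrt_logits(p):
    """One valid factor B = diag(sqrt(p)) - p sqrt(p)^T, so B B^T = F
    (this relies on sum(p) == 1)."""
    s = [math.sqrt(pi) for pi in p]
    n = len(p)
    return [[(s[i] if i == j else 0.0) - p[i] * s[j] for j in range(n)]
            for i in range(n)]


p = [0.25, 0.25, 0.5]               # hypothetical softmax probabilities
F = fisher_wrt_logits(p)
B = fisher_factor_wrt_logits(p)

v = [1.0, 0.0, 0.0]
inner = matvec(transpose(B), v)     # analogue of multiply_fisher_factor_transpose
roundtrip = matvec(B, inner)        # analogue of multiply_fisher_factor
direct = matvec(F, v)               # analogue of multiply_fisher
```

Applying the transpose factor and then the factor recovers the plain Fisher product, up to floating-point rounding.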
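multiply_generalized_gauss_newton( vecs )
Multiply vecs by the generalized Gauss-Newton (GGN) matrix of the total loss.
multiply_generalized_gauss_newton_factor( loss_inner_vecs )
Multiply loss_inner_vecs by the factor of the GGN of the total loss.
multiply_generalized_gauss_newton_factor_transpose( vecs )
Multiply vecs by the transpose of the factor of the GGN of the total loss.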
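For a GGN of the form G = J^T H J, one way to obtain a factor is to write the output-space Hessian as H = L L^T (for example via a Cholesky decomposition, assuming H is positive definite) and set B = J^T L. Then B^T*v = L^T(J v) has the size of the network outputs, which plays the role of the inner shape in this sketch. The Jacobian, Hessian and helper names are hypothetical, and the class may construct its factors differently.

```python
import math


def transpose(m):
    return [list(col) for col in zip(*m)]


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def cholesky(a):
    """Lower-triangular L with L L^T = a (a must be positive definite)."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


J = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]   # outputs-by-parameters Jacobian
H = [[2.0, 1.0], [1.0, 3.0]]              # loss Hessian w.r.t. outputs
L = cholesky(H)


def ggn_factor(u):
    """B u with B = J^T L; u has the inner (output) shape."""
    return matvec(transpose(J), matvec(L, u))


def ggn_factor_transpose(v):
    """B^T v = L^T (J v); the result has the inner (output) shape."""
    return matvec(transpose(L), matvec(J, v))


v = [1.0, 1.0, 1.0]
inner = ggn_factor_transpose(v)
roundtrip = ggn_factor(inner)   # B B^T v = J^T H J v
```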
multiply_hessian( vecs )
Multiply vecs by the Hessian of the total loss.
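Unlike the Fisher and GGN, the Hessian can be indefinite, so it has no real factor B with B*B^T = M in general, which fits with the class exposing no Hessian factor methods. To show what a Hessian-vector product computes, the sketch below swaps in a central finite difference of the gradient along v, (g(x + eps*v) - g(x - eps*v)) / (2*eps), which is only an approximation; exact products are usually obtained by differentiating g(x)·v a second time (Pearlmutter's R-operator). The loss, point and helper names are hypothetical.

```python
def grad(x):
    """Gradient of f(x) = (x0**4 + x1**4) / 4 + x0 * x1 (a hypothetical loss)."""
    x0, x1 = x
    return [x0 ** 3 + x1, x1 ** 3 + x0]


def hessian_vector_product_fd(grad_fn, x, v, eps=1e-4):
    """Approximate H v by a central difference of the gradient along v:
    (grad(x + eps v) - grad(x - eps v)) / (2 eps)."""
    plus = grad_fn([xi + eps * vi for xi, vi in zip(x, v)])
    minus = grad_fn([xi - eps * vi for xi, vi in zip(x, v)])
    return [(a - b) / (2 * eps) for a, b in zip(plus, minus)]


x = [1.0, 2.0]
v = [1.0, -1.0]
approx = hessian_vector_product_fd(grad, x, v)
# Exact Hessian at x is [[3*x0**2, 1], [1, 3*x1**2]] = [[3, 1], [1, 12]],
# so H v = [2, -11].
exact = [3.0 * 1.0 - 1.0, 1.0 - 12.0]
```

The approximation error here is of order eps**2, so `approx` agrees with `exact` to roughly eight decimal places.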