tfp.substrates.jax.glm.fit_one_step

Runs one step of Fisher scoring.
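For reference, here is a textbook form of the Fisher scoring update. It is written with the penalized loss from the l2_regularizer entry (penalty strength λ) and the learning_rate entry (step size α). The symbols η = Xw + offset, μ = E[y | η], v = Var[y | η] (including any dispersion), and μ'(η) = dμ/dη are introduced here for illustration only. TFP's internal parameterization and scaling may differ.

```latex
\begin{aligned}
\ell(w) &= \sum_{i=1}^{n} -\log p(y_i \mid x_i, w), \qquad \eta = Xw + \text{offset}, \\
\nabla \ell(w) &= -X^\top \left[ \frac{\mu'(\eta)}{v} \odot (y - \mu) \right], \\
\mathcal{I}(w) &= X^\top \operatorname{diag}\!\left( \frac{\mu'(\eta)^2}{v} \right) X, \\
w_{\text{next}} &= w - \alpha \left( \mathcal{I}(w) + 2\lambda I \right)^{-1} \left( \nabla \ell(w) + 2\lambda w \right).
\end{aligned}
```

With a canonical link, the expected Hessian \mathcal{I}(w) equals the observed Hessian, so this step coincides with a (damped) Newton step.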

Args:
model_matrix (Batch of) float-like, matrix-shaped Tensor where each row represents a sample's features.
response (Batch of) vector-shaped Tensor where each element represents a sample's observed response (to the corresponding row of features). Must have the same dtype as model_matrix.
model tfp.glm.ExponentialFamily-like instance used to construct the negative log-likelihood loss, gradient, and expected Hessian (i.e., the Fisher information matrix).
model_coefficients_start Optional (batch of) vector-shaped Tensor representing the initial model coefficients, one for each column in model_matrix. Must have the same dtype as model_matrix. Default value: zeros.
predicted_linear_response_start Optional Tensor whose shape and dtype match response; represents the offset-shifted initial linear predictions based on model_coefficients_start. Default value: offset if model_coefficients_start is None, and tf.linalg.matvec(model_matrix, model_coefficients_start) + offset otherwise.
l2_regularizer Optional scalar Tensor representing the strength of the L2 regularization penalty, i.e., loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w||_2^2. Default value: None (i.e., no L2 regularization).
dispersion Optional (batch of) Tensor representing the response dispersion, as in p(y|theta) := exp((y theta - A(theta)) / dispersion). Must broadcast with rows of model_matrix. Default value: None (i.e., "no dispersion").
offset Optional Tensor representing a constant shift applied to predicted_linear_response. Must broadcast to response. Default value: None (i.e., tf.zeros_like(response)).
learning_rate Optional (batch of) scalar Tensor used to dampen iterative progress. Typically only needed if optimization diverges; it should be no larger than 1 and is typically very close to 1. Default value: None (i.e., 1).
fast_unsafe_numerics Optional Python bool indicating whether the linear solve uses a Cholesky decomposition (True) or a QR decomposition (False). Default value: True (i.e., "prefer speed via Cholesky decomposition").
l2_regularization_penalty_factor Optional (batch of) vector-shaped Tensor representing a separate penalty factor for each model coefficient, with length equal to the number of columns in model_matrix. Each penalty factor multiplies l2_regularizer, allowing differential regularization; a factor of 0 leaves the corresponding coefficient unregularized. Default value: 1 for all coefficients. The resulting loss is loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w * l2_regularization_penalty_factor||_2^2.
name Python str used as name prefix to ops created by this function. Default value: "fit_one_step".
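The default rules for offset and predicted_linear_response_start, together with the penalized loss in the l2_regularizer and l2_regularization_penalty_factor entries, can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not TFP's implementation: the helper names are hypothetical, the model is assumed to be Bernoulli with a logit link, and the penalty follows the l2_regularization_penalty_factor formula exactly as written.

```python
import jax
import jax.numpy as jnp


def default_offset(response, offset=None):
    # offset=None behaves like tf.zeros_like(response).
    return jnp.zeros_like(response) if offset is None else offset


def default_predicted_linear_response_start(
        model_matrix, response, model_coefficients_start=None, offset=None):
    # Hypothetical helper mirroring the documented default.
    offset = default_offset(response, offset)
    if model_coefficients_start is None:
        # Zero initial coefficients: the linear prediction is just the offset.
        return offset
    return model_matrix @ model_coefficients_start + offset


def bernoulli_neg_log_likelihood(model_matrix, response, w, offset=None):
    # Assumed example model: logistic regression, logits = X w + offset.
    logits = model_matrix @ w + default_offset(response, offset)
    # -log p(y | logits) = softplus(logits) - y * logits, summed over samples.
    return jnp.sum(jax.nn.softplus(logits) - response * logits)


def penalized_loss(model_matrix, response, w, l2_regularizer=None,
                   l2_regularization_penalty_factor=None, offset=None):
    # Hypothetical helper: the documented loss for the Bernoulli example.
    loss = bernoulli_neg_log_likelihood(model_matrix, response, w, offset)
    if l2_regularizer is None:
        return loss  # No L2 regularization.
    if l2_regularization_penalty_factor is None:
        # Default penalty factor of 1 for every coefficient.
        l2_regularization_penalty_factor = jnp.ones_like(w)
    # l2_regularizer * ||w * l2_regularization_penalty_factor||_2^2
    return loss + l2_regularizer * jnp.sum(
        (w * l2_regularization_penalty_factor) ** 2)
```

For example, a penalty factor of [1., 0.] regularizes only the first coefficient, so the penalty term for w = [1., -1.] and l2_regularizer = 0.1 is 0.1.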

Returns:
model_coefficients (Batch of) vector-shaped Tensor; represents the next estimate of the model coefficients, one for each column in model_matrix.
predicted_linear_response response-shaped Tensor representing the linear predictions based on the new model_coefficients, i.e., tf.linalg.matvec(model_matrix, model_coefficients) + offset.
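To show how the two return values relate, here is a minimal JAX sketch of a single Fisher scoring step. It assumes a Bernoulli model with a logit link (logistic regression) and uses only the Cholesky solve that corresponds to fast_unsafe_numerics=True; the QR path is not sketched. The function name fisher_scoring_step_logistic is hypothetical, and this is not TFP's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def fisher_scoring_step_logistic(model_matrix, response, model_coefficients,
                                 offset=None, l2_regularizer=None,
                                 l2_regularization_penalty_factor=None,
                                 learning_rate=None):
    """One Fisher scoring step for Bernoulli (logit-link) regression.

    Hypothetical helper. Returns (next coefficients, predicted linear response).
    """
    if offset is None:
        offset = jnp.zeros_like(response)
    if learning_rate is None:
        learning_rate = 1.0
    eta = model_matrix @ model_coefficients + offset
    mean = jax.nn.sigmoid(eta)
    # Canonical logit link: d(mean)/d(eta) == Var(y) == mean * (1 - mean).
    variance = mean * (1.0 - mean)
    # Gradient of the negative log-likelihood.
    grad = -model_matrix.T @ (response - mean)
    # Expected Hessian (Fisher information): X^T diag(variance) X.
    fisher = model_matrix.T @ (variance[:, None] * model_matrix)
    if l2_regularizer is not None:
        if l2_regularization_penalty_factor is None:
            l2_regularization_penalty_factor = jnp.ones_like(model_coefficients)
        # Penalty l2_regularizer * ||w * factor||^2 adds 2*lambda*factor^2*w
        # to the gradient and 2*lambda*diag(factor^2) to the Hessian.
        d = l2_regularization_penalty_factor ** 2
        grad = grad + 2.0 * l2_regularizer * d * model_coefficients
        fisher = fisher + 2.0 * l2_regularizer * jnp.diag(d)
    # Solve (Fisher + penalty) step = grad via a Cholesky factorization.
    step = cho_solve(cho_factor(fisher), grad)
    next_coefficients = model_coefficients - learning_rate * step
    # Second return value: X @ next_coefficients + offset.
    predicted_linear_response = model_matrix @ next_coefficients + offset
    return next_coefficients, predicted_linear_response
```

Calling the step repeatedly, feeding each returned coefficient vector back in, drives the penalized gradient toward zero. Because the step direction does not depend on learning_rate, starting from zeros with learning_rate=0.5 yields exactly half of the full step.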