tf_agents.distributions.utils.merge_to_parameters_from_dict

Merges a dict with the same structure as parameters_to_dict(value) into value, returning a new Params.

For more details, see the example below and the documentation of parameters_to_dict.
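The merge rule can be modeled in plain Python, which makes the fallback behavior concrete without TensorFlow. This is a minimal sketch, not TF-Agents' implementation. ToyParams, toy_parameters_to_dict and toy_merge are hypothetical stand-ins for Params, parameters_to_dict and merge_to_parameters_from_dict, and plain lists stand in for tensors. The sketch assumes three things: nested parameters are matched key by key, keys missing from the dict keep the value stored in value, and keys that value.params lacks raise ValueError.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToyParams:
    """Hypothetical stand-in for Params: a type name plus named parameters.

    Parameter values may themselves be ToyParams (e.g. a bijector's scale).
    """
    type_: str
    params: Dict[str, Any] = field(default_factory=dict)


def toy_parameters_to_dict(value: ToyParams) -> Dict[str, Any]:
    """Turn nested ToyParams into nested dicts; type names are dropped."""
    return {
        k: toy_parameters_to_dict(v) if isinstance(v, ToyParams) else v
        for k, v in value.params.items()
    }


def toy_merge(value: ToyParams, params_dict: Dict[str, Any]) -> ToyParams:
    """Build a new ToyParams, taking entries from params_dict where present."""
    unknown = set(params_dict) - set(value.params)
    if unknown:
        # Keys that value.params does not have are rejected.
        raise ValueError(f"Unknown parameter keys: {sorted(unknown)}")
    new_params = {}
    for k, default in value.params.items():
        if isinstance(default, ToyParams):
            # Recurse; a missing subdict means "keep every nested default".
            new_params[k] = toy_merge(default, params_dict.get(k, {}))
        else:
            # Missing leaf keys fall back to the value stored in `value`.
            new_params[k] = params_dict.get(k, default)
    return ToyParams(value.type_, new_params)


# Toy version of the structure used in the example.
p = ToyParams("TransformedDistribution", {
    "distribution": ToyParams("MultivariateNormalDiag",
                              {"loc": [1.0, 1.0], "scale_diag": [2.0, 3.0]}),
    "bijector": ToyParams("ScaleMatvecLinearOperator", {
        "scale": ToyParams("LinearOperatorFullMatrix",
                           {"matrix": [[1.0, 2.0], [-1.0, 0.0]]}),
        "adjoint": True,
    }),
})

params_dict = toy_parameters_to_dict(p)
params_dict["bijector"]["scale"]["matrix"] = [[2.0, 0.0], [0.0, 2.0]]
del params_dict["distribution"]  # dropped key: p's distribution is reused
new_p = toy_merge(p, params_dict)

print(new_p.params["bijector"].params["scale"].params["matrix"])  # [[2.0, 0.0], [0.0, 2.0]]
print(new_p.params["distribution"].params["loc"])  # [1.0, 1.0]
```

In this model the dict carries only parameter values, so toy_merge needs value as well as params_dict: the type to rebuild at each level comes from value.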

Example:

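```python
import tensorflow as tf
import tensorflow_probability as tfp
from tf_agents.distributions import utils

scale_matrix = tf.Variable([[1.0, 2.0], [-1.0, 0.0]])
d = tfp.distributions.MultivariateNormalDiag(
    loc=[1.0, 1.0], scale_diag=[2.0, 3.0], validate_args=True)
b = tfp.bijectors.ScaleMatvecLinearOperator(
    scale=tf.linalg.LinearOperatorFullMatrix(matrix=scale_matrix),
    adjoint=True)
# Calling a bijector on a distribution yields a TransformedDistribution.
b_d = b(d)
p = utils.get_parameters(b_d)

# Convert to a nested dict and replace the bijector's scale matrix.
params_dict = utils.parameters_to_dict(p)
new_scale_matrix = tf.constant([[2.0, 0.0], [0.0, 2.0]])  # illustrative 2x2 matrix
params_dict["bijector"]["scale"]["matrix"] = new_scale_matrix

new_params = utils.merge_to_parameters_from_dict(p, params_dict)

# new_d is a `ScaleMatvecLinearOperator()(MultivariateNormalDiag)` with
# a new scale matrix.
new_d = utils.make_from_parameters(new_params)
```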

Args:
  value: A Params from which params_dict was derived.
  params_dict: A nested dict created, e.g., by calling parameters_to_dict(value) and then editing it to change parameters. NOTE: If any keys are missing from the dict, the corresponding "default" value in value is used instead.

Returns:
  A new Params object, which can then be turned into, e.g., a tfp.Distribution via make_from_parameters.

Raises:
  ValueError: If params_dict contains keys that are not present in value.params.
  KeyError: If a subdict entry is missing for a nested value in value.params.
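Because keys in params_dict that value.params does not have raise ValueError, a misspelled parameter name surfaces as an error instead of being silently ignored. When editing a deeply nested dict, it can help to list every such key before merging. The sketch below does that over plain nested dicts. unknown_keys is a hypothetical helper, not part of TF-Agents; reference stands in for the output of parameters_to_dict(value), and the sketch assumes unknown keys matter at every nesting level.

```python
from typing import Any, Dict, List, Tuple


def unknown_keys(reference: Dict[str, Any], edited: Dict[str, Any],
                 path: Tuple[str, ...] = ()) -> List[str]:
    """Return dotted paths of keys in `edited` that do not exist in `reference`."""
    found = []
    for key, sub in edited.items():
        if key not in reference:
            found.append(".".join(path + (key,)))
        elif isinstance(sub, dict) and isinstance(reference[key], dict):
            # Descend only where both sides are nested parameter dicts.
            found.extend(unknown_keys(reference[key], sub, path + (key,)))
    return found


# Plain-list stand-in for parameters_to_dict(value) from the example.
reference = {
    "distribution": {"loc": [1.0, 1.0], "scale_diag": [2.0, 3.0]},
    "bijector": {"scale": {"matrix": [[1.0, 2.0], [-1.0, 0.0]]}, "adjoint": True},
}
edited = {
    "distribution": {"scale_diagg": [1.0, 1.0]},  # typo for scale_diag
    "bijector": {"scale": {"matrix": [[2.0, 0.0], [0.0, 2.0]]}},
}
print(unknown_keys(reference, edited))  # ['distribution.scale_diagg']
```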