tf_agents.policies.PolicySaver

A PolicySaver lets you save a tf_policy.TFPolicy to a SavedModel.

The save() method exports a saved model to the requested export location. The SavedModel that is exported can be loaded via tf.compat.v2.saved_model.load (or tf.saved_model.load in TF2). The following signatures (concrete functions) are available: action, get_initial_state, and get_train_step.

When the SavedModel is loaded, it also exposes a model_variables attribute, which gives access to the model variables so that you can update them if needed.

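Usage:

```python
from tf_agents.policies import PolicySaver

# agent is an existing TF-Agents agent.
my_policy = agent.collect_policy
saver = PolicySaver(my_policy, batch_size=None)

for i in range(...):
  agent.train(...)
  if i % 100 == 0:
    # Export a separate SavedModel directory every 100 iterations.
    saver.save('policy_%d' % i)
```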
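A loop like this leaves one export directory per saved iteration. As a stdlib-only sketch, assuming each directory is named 'policy_<iteration index>', the hypothetical helpers saved_policy_dirs and latest_policy_dir (not part of TF-Agents) list which directories a run would produce and pick the most recent one to load.

```python
import re


def saved_policy_dirs(num_iterations, every=100):
    """Directory names a loop like the one above would pass to saver.save.

    Hypothetical helper: assumes the 'policy_%d' % i naming and the
    i % every == 0 save condition.
    """
    return ['policy_%d' % i for i in range(num_iterations) if i % every == 0]


def latest_policy_dir(dir_names):
    """Return the 'policy_<step>' name with the largest step, or None."""
    pattern = re.compile(r'^policy_(\d+)$')
    best_name, best_step = None, -1
    for name in dir_names:
        match = pattern.match(name)
        # Compare steps numerically, not as strings.
        if match and int(match.group(1)) > best_step:
            best_name, best_step = name, int(match.group(1))
    return best_name


dirs = saved_policy_dirs(350)
print(dirs)                     # ['policy_0', 'policy_100', 'policy_200', 'policy_300']
print(latest_policy_dir(dirs))  # policy_300
```

Parsing the step as an integer matters once steps differ in digit count: as strings, 'policy_900' sorts after 'policy_1000'.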

To load and use the saved policy directly:

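```python
import tensorflow as tf

saved_policy = tf.compat.v2.saved_model.load('policy_0')
policy_state = saved_policy.get_initial_state(batch_size=3)
time_step = ...  # a batched TimeStep from your environment
while True:
  policy_step = saved_policy.action(time_step, policy_state)
  policy_state = policy_step.state
  # f is a placeholder for stepping the environment with the chosen action.
  time_step = f(policy_step.action)
  ...
```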
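The loop above feeds the policy_state returned by each action call into the next one, which is what keeps recurrent policies consistent across steps. The pattern can be exercised without TensorFlow using a toy stand-in: ToySavedPolicy is hypothetical and only imitates the call shape of the loaded object, time steps are plain lists of observations, and PolicyStep reuses the (action, state, info) field names of the TF-Agents PolicyStep namedtuple.

```python
import collections

# Same field names as the TF-Agents PolicyStep; defined locally so this runs standalone.
PolicyStep = collections.namedtuple('PolicyStep', ('action', 'state', 'info'))


class ToySavedPolicy:
    """Hypothetical stand-in for a loaded SavedModel policy.

    Its state is a per-entry count of actions taken, which is enough to show
    why the state returned by action() must be passed back in.
    """

    def get_initial_state(self, batch_size):
        return [0] * batch_size

    def action(self, time_step, policy_state):
        new_state = [count + 1 for count in policy_state]
        # Toy action: the observation plus the number of steps taken so far.
        actions = [obs + count for obs, count in zip(time_step, new_state)]
        return PolicyStep(action=actions, state=new_state, info=())


def run(policy, time_step, num_steps, batch_size=3):
    policy_state = policy.get_initial_state(batch_size=batch_size)
    for _ in range(num_steps):
        policy_step = policy.action(time_step, policy_state)
        policy_state = policy_step.state  # carry the state into the next call
        time_step = policy_step.action    # toy "environment": feed actions back
    return time_step, policy_state


final_time_step, final_state = run(ToySavedPolicy(), [0, 10, 20], num_steps=4)
print(final_state)      # [4, 4, 4]
print(final_time_step)  # [10, 20, 30]
```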

Or, to use the distributional form:

```python
import tensorflow as tf

batch_size = 3
saved_policy = tf.compat.v2.saved_model.load('policy_0')
policy_state = saved_policy.get_initial_state(batch_size=batch_size)
time_step = ...
while True:
  policy_step = saved_policy.distribution(time_step, policy_state)
  policy_state = policy_step.state
  # policy_step.action is a distribution that already has batch_size entries,
  # so sample() draws one action per entry.
  time_step = f(policy_step.action.sample())
  ...
```
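In the distributional form, policy_step.action is a batched distribution: it already holds one component per batch entry, so sample() yields one action per entry, while passing a sample shape such as batch_size stacks extra draws on top. The toy class below is a hypothetical, stdlib-only stand-in that imitates this shape behaviour for a batch of categorical distributions; it is not a TensorFlow Probability class.

```python
import random


class ToyBatchedCategorical:
    """Hypothetical batch of categorical distributions, one per batch entry."""

    def __init__(self, probs_per_entry, seed=None):
        self.probs_per_entry = probs_per_entry
        self._rng = random.Random(seed)

    def _draw_batch(self):
        # One draw per batch entry.
        return [self._rng.choices(range(len(p)), weights=p)[0]
                for p in self.probs_per_entry]

    def sample(self, sample_shape=None):
        if sample_shape is None:
            return self._draw_batch()  # shape [batch_size]
        # A sample shape adds a leading dimension: [sample_shape, batch_size].
        return [self._draw_batch() for _ in range(sample_shape)]


batch_size = 3
dist = ToyBatchedCategorical([[0.1, 0.9], [0.5, 0.5], [1.0, 0.0]], seed=0)
print(len(dist.sample()))                             # 3: one action per entry
print([len(row) for row in dist.sample(batch_size)])  # [3, 3, 3]: nine actions
```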

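If you use the flattened (signature) version instead, inputs and outputs are limited to dicts keyed by the specs' name fields:

```python
import tensorflow as tf

saved_policy = tf.compat.v2.saved_model.load('policy_0')
get_initial_state_fn = saved_policy.signatures['get_initial_state']
action_fn = saved_policy.signatures['action']

policy_state_dict = get_initial_state_fn(batch_size=3)
time_step_dict = ...
while True:
  # Signature functions take keyword arguments, so pass the time step and
  # policy state fields as one flat set of named inputs.
  time_step_state = dict(time_step_dict)
  time_step_state.update(policy_state_dict)
  policy_step_dict = action_fn(**time_step_state)
  # extract_policy_state_fields and extract_action_fields are placeholders
  # for your own code that picks the matching keys out of the output dict.
  policy_state_dict = extract_policy_state_fields(policy_step_dict)
  action_dict = extract_action_fields(policy_step_dict)
  time_step_dict = f(action_dict)
  ...
```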
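The loop above leaves extract_policy_state_fields and extract_action_fields undefined. One minimal way to write them is sketched below, assuming you already know which output keys hold policy state and which hold the action (here the helpers take those key lists as an extra argument). The key names ('action', 'policy_state/h', and so on) are made up, so check them against the keys your exported signature actually returns. merge_inputs is likewise hypothetical; it refuses overlapping names because dict.update would otherwise silently overwrite one field with another.

```python
def merge_inputs(time_step_dict, policy_state_dict):
    """Build the single flat dict of named inputs for the action signature."""
    overlap = set(time_step_dict) & set(policy_state_dict)
    if overlap:
        raise ValueError('Duplicate input names: %s' % sorted(overlap))
    merged = dict(time_step_dict)
    merged.update(policy_state_dict)
    return merged


def extract_policy_state_fields(policy_step_dict, state_keys):
    """Hypothetical helper: keep only the outputs that are policy-state fields."""
    return {k: policy_step_dict[k] for k in state_keys}


def extract_action_fields(policy_step_dict, action_keys):
    """Hypothetical helper: keep only the outputs that are action fields."""
    return {k: policy_step_dict[k] for k in action_keys}


# Made-up key names; real ones come from the specs' name fields.
policy_state_dict = {'policy_state/h': [0.0, 0.0]}
time_step_dict = {'step_type': 0, 'reward': 0.0, 'discount': 1.0, 'observation': [1.0]}
inputs = merge_inputs(time_step_dict, policy_state_dict)

policy_step_dict = {'action': [1], 'policy_state/h': [0.3, -0.2]}
state_keys = list(policy_state_dict)  # assume outputs reuse the input state names
print(extract_policy_state_fields(policy_step_dict, state_keys))
print(extract_action_fields(policy_step_dict, ['action']))
```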

Args
policy A TF Policy.
batch_size The number of batch entries the policy will process at a time. This must be either None (unknown batch size) or a Python integer.
use_nest_path_signatures If True, SavedModel spec signatures are created based on the nest structure of the specs. Otherwise, all specs must have unique names.
seed Random seed for the policy.action call, if any (this should usually be None, except for testing).
train_step Variable holding the train step for the policy. The value saved will be set at the time saver.save is called. If not provided, train_step defaults to -1. Note that because the train step must be a variable, it is not safe to create it directly in TF1, so in that case this parameter is required.
input_fn_and_spec An (input_fn, tensor_spec) tuple where input_fn is a function that takes inputs according to tensor_spec and converts them to the (time_step, policy_state) tuple that is used as the input to the action_fn. When input_fn_and_spec is set, tensor_spec is the input for the action signature. When input_fn_and_spec is None, the action signature takes as input (time_step, policy_state).
metadata A dictionary of tf.Variables to be saved along with the policy.
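The batch_size argument must be None or a Python integer, and a ValueError is raised for anything else, including integers that are not > 0. A stdlib sketch of an equivalent check follows; check_batch_size is a hypothetical name, not the library's internal implementation, and rejecting bool (which subclasses int in Python) is an assumption about the intended behaviour.

```python
def check_batch_size(batch_size):
    """Accept None (unknown batch size) or a positive Python int; else raise."""
    if batch_size is None:
        return None
    # bool is a subclass of int in Python, so exclude it explicitly.
    if isinstance(batch_size, bool) or not isinstance(batch_size, int):
        raise ValueError('batch_size must be None or a Python int, got %r' % (batch_size,))
    if batch_size <= 0:
        raise ValueError('batch_size must be > 0, got %d' % batch_size)
    return batch_size


for value in (None, 1, 32):
    check_batch_size(value)  # all accepted
for value in (0, -4, 2.0, '8', True):
    try:
        check_batch_size(value)
    except ValueError as err:
        print(err)
```

Passing None keeps the batch dimension unknown, which is usually the more flexible choice when the exported policy will be driven at different batch sizes.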

Raises
TypeError If policy is not an instance of TFPolicy.
TypeError If metadata is not a dictionary of tf.Variables.
ValueError If use_nest_path_signatures is not used and any of the following policy specs are missing names, or the names collide: policy.time_step_spec, policy.action_spec, policy.policy_state_spec, policy.info_spec.
ValueError If batch_size is not either None or a Python integer > 0.
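When use_nest_path_signatures is not used, every spec across time_step_spec, action_spec, policy_state_spec and info_spec needs a distinct, non-empty name. The sketch below shows one way to detect missing or colliding names, assuming the specs have already been flattened into (spec group, name) pairs; find_name_problems is hypothetical and works on plain tuples rather than tf.TensorSpec objects.

```python
import collections


def find_name_problems(named_specs):
    """Return (missing, collisions) for a list of (spec_group, name) pairs.

    missing: spec groups containing a spec whose name is None or empty.
    collisions: names used by more than one spec, mapped to their groups.
    """
    missing = sorted({group for group, name in named_specs if not name})
    groups_by_name = collections.defaultdict(list)
    for group, name in named_specs:
        if name:
            groups_by_name[name].append(group)
    collisions = {name: groups for name, groups in groups_by_name.items()
                  if len(groups) > 1}
    return missing, collisions


# Hypothetical flattened specs: the policy state reuses the name 'observation'
# and one info spec has no name.
specs = [
    ('time_step_spec', 'step_type'),
    ('time_step_spec', 'observation'),
    ('action_spec', 'action'),
    ('policy_state_spec', 'observation'),
    ('info_spec', None),
]
missing, collisions = find_name_problems(specs)
print(missing)     # ['info_spec']
print(collisions)  # {'observation': ['time_step_spec', 'policy_state_spec']}
```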

Attributes
action_input_spec Tuple (time_step_spec, policy_state_spec) for feeding action.

This describes the input of action in the SavedModel.

This may differ from the original policy if use_nest_path_signatures was enabled.

policy

policy_state_spec Spec that describes the output of get_initial_state in the SavedModel.

This may differ from the original policy if use_nest_path_signatures was enabled.

policy_step_spec Spec that describes the output of action in the SavedModel.

This may differ from the original policy if use_nest_path_signatures was enabled.

signatures Get the (flat) signatures used when exporting the SavedModel.

Methods

get_metadata

Returns the metadata of the policy.

Returns
A dictionary of tf.Variables.

get_train_step

Returns the train step of the policy.

Returns
An integer.

register_concrete_function

Registers a function into the saved model.

This gives you the flexibility to register any kind of polymorphic function by creating the concrete function that you wish to register.

Args
name Name of the attribute to use for the saved fn.
fn Function to register. Must be a callable following the input_spec as a single parameter.
assets Any extra checkpoint dependencies that must be captured in the module. Note that variables are captured automatically.

register_function

Registers a function into the saved model.

Args
name Name of the attribute to use for the saved fn.
fn Function to register. Must be a callable following the input_spec as a single parameter.
input_spec A nest of tf.TypeSpec representing the time_steps. Provided by the user.
outer_dims The outer dimensions the saved fn will process at a time. By default a batch dimension is added to the input_spec.
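Conceptually, outer_dims is prepended to the shape of each spec in input_spec, which is how a batch dimension gets added by default. The stdlib sketch below shows only that shape arithmetic, with None standing for an unknown dimension as in TensorFlow shapes; it assumes the default amounts to a single batch dimension of unknown size, uses a flat dict instead of an arbitrary nest, and add_outer_dims is a hypothetical helper.

```python
def add_outer_dims(shapes, outer_dims=(None,)):
    """Prepend outer_dims to every shape in a dict of per-spec shapes.

    None stands for an unknown dimension. The default (None,) models a
    single batch dimension of unknown size.
    """
    return {name: tuple(outer_dims) + tuple(shape) for name, shape in shapes.items()}


# Hypothetical unbatched time_step shapes.
time_step_shapes = {'step_type': (), 'reward': (), 'discount': (), 'observation': (4,)}
print(add_outer_dims(time_step_shapes)['observation'])                 # (None, 4)
print(add_outer_dims(time_step_shapes, outer_dims=(8, 10))['reward'])  # (8, 10)
```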

save

Save the policy to the given export_dir.

Args
export_dir Directory to save the policy to.
options Optional tf.saved_model.SaveOptions object.

save_checkpoint

Saves the policy as a checkpoint to the given export_dir.

This will only work with checkpoints generated in TF2.x.

For the checkpoint to be useful, users should first call save to generate a SavedModel of the policy. Checkpoints can then be used to update the policy's variables without reloading the SavedModel or saving multiple copies of the saved_model.pb file.

The checkpoint is always created in the sub-directory 'variables/' and the checkpoint file prefix used is 'variables'. The checkpoint files are as follows:

  • export_dir/variables/variables.index
  • export_dir/variables/variables.data-xxxxx-of-xxxxx

This makes the files compatible with the checkpoint part of a full SavedModel, so you can load a SavedModel assembled from the graph part of a full SavedModel and the variables from a checkpoint.

Args
export_dir Directory to save the checkpoint to.
options Optional tf.train.CheckpointOptions object.
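The file layout described above can be computed ahead of time, for example to check that a save_checkpoint call produced what a later load expects. This is a stdlib sketch; it assumes the usual TensorFlow checkpoint naming of <prefix>.index plus <prefix>.data-<shard>-of-<num_shards> with five-digit zero-padded shard numbers, and expected_checkpoint_files is a hypothetical helper, not part of TF-Agents.

```python
import os


def expected_checkpoint_files(export_dir, num_shards=1):
    """Paths save_checkpoint is expected to write under export_dir/variables/."""
    prefix = os.path.join(export_dir, 'variables', 'variables')
    files = [prefix + '.index']
    files += ['%s.data-%05d-of-%05d' % (prefix, shard, num_shards)
              for shard in range(num_shards)]
    return files


for path in expected_checkpoint_files('policy_100'):
    print(path)
```

Because these paths match the variables part of a full SavedModel, the same export_dir that save wrote to can receive updated checkpoints without rewriting saved_model.pb.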