tf_agents.policies.policy_saver.PolicySaver

A PolicySaver allows you to save a tf_policy.Policy to SavedModel.

tf_agents.policies.policy_saver.PolicySaver(
    policy, batch_size=None, use_nest_path_signatures=True, seed=None,
    train_step=None, input_fn_and_spec=None
)

The save() method exports a SavedModel to the requested export location. The exported SavedModel can be loaded via tf.compat.v2.saved_model.load (or tf.saved_model.load in TF2). It exposes the following signatures (concrete functions): action, get_initial_state, and get_train_step.

When the SavedModel is loaded, the attribute model_variables is also available. It gives access to the model variables so that they can be updated if needed.

Usage:


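from tf_agents.policies.policy_saver import PolicySaver

my_policy = agent.collect_policy
saver = PolicySaver(my_policy, batch_size=None)

for i in range(...):
  agent.train(...)
  if i % 100 == 0:
    # Export one SavedModel directory per saved iteration.
    saver.save('policy_%d' % i)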

To load and use the saved policy directly:

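import tensorflow as tf

saved_policy = tf.compat.v2.saved_model.load('policy_0')
policy_state = saved_policy.get_initial_state(batch_size=3)
time_step = ...
while True:
  policy_step = saved_policy.action(time_step, policy_state)
  # Carry the policy state forward to the next call.
  policy_state = policy_step.state
  # f is a placeholder for stepping the environment with the chosen action.
  time_step = f(policy_step.action)
  ...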

If using the flattened (signature) version, you will be limited to using dicts keyed by the specs' name fields.

saved_policy = tf.compat.v2.saved_model.load('policy_0')
get_initial_state_fn = saved_policy.signatures['get_initial_state']
action_fn = saved_policy.signatures['action']

policy_state_dict = get_initial_state_fn(batch_size=3)
time_step_dict = ...
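while True:
  # Merge the time step and policy state into one flat dict of inputs.
  time_step_state = dict(time_step_dict)
  time_step_state.update(policy_state_dict)
  # Signature functions take keyword arguments, so unpack the dict.
  policy_step_dict = action_fn(**time_step_state)
  # extract_policy_state_fields, extract_action_fields and f are
  # placeholders that the caller must supply.
  policy_state_dict = extract_policy_state_fields(policy_step_dict)
  action_dict = extract_action_fields(policy_step_dict)
  time_step_dict = f(action_dict)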
  ...
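
The loop above calls extract_policy_state_fields and extract_action_fields without defining them. A minimal sketch of what such helpers might look like follows, assuming the flattened dicts are keyed by the specs' names. The names in POLICY_STATE_NAMES and ACTION_NAMES are hypothetical; in practice they would come from the names in policy.policy_state_spec and policy.action_spec.

```python
def select_fields(flat_dict, names):
    """Return the sub-dict of flat_dict restricted to the given key names."""
    missing = [name for name in names if name not in flat_dict]
    if missing:
        raise KeyError("Missing expected fields: %s" % missing)
    return {name: flat_dict[name] for name in names}


# Hypothetical spec names, used only for illustration.
POLICY_STATE_NAMES = ("lstm_h", "lstm_c")
ACTION_NAMES = ("action",)


def extract_policy_state_fields(policy_step_dict):
    """Pick out the entries that belong to the policy state."""
    return select_fields(policy_step_dict, POLICY_STATE_NAMES)


def extract_action_fields(policy_step_dict):
    """Pick out the entries that belong to the action."""
    return select_fields(policy_step_dict, ACTION_NAMES)
```

Selecting by an explicit list of names, rather than by position, keeps the helpers independent of the order in which the signature returns its outputs.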

Args:

  • policy: A TF Policy.
  • batch_size: The number of batch entries the policy will process at a time. This must be either None (unknown batch size) or a python integer.
  • use_nest_path_signatures: If True, SavedModel spec signatures are created based on the structure of the specs. Otherwise, all specs must have unique names.
  • seed: Random seed for the policy.action call, if any (this should usually be None, except for testing).
  • train_step: Variable holding the train step for the policy. The value saved will be set at the time saver.save is called. If not provided, train_step defaults to -1.
  • input_fn_and_spec: A (input_fn, tensor_spec) tuple where input_fn is a function that takes inputs according to tensor_spec and converts them to the (time_step, policy_state) tuple that is used as the input to the action_fn. When input_fn_and_spec is set, tensor_spec is the input for the action signature. When input_fn_and_spec is None, the action signature takes as input (time_step, policy_state).
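
To illustrate the input_fn contract described for input_fn_and_spec, here is a TensorFlow-free sketch. TimeStepStandIn is a stand-in that only mirrors the field names of a TF-Agents TimeStep; it is not the library class. The raw dict keys, including "features", are hypothetical, and the policy is assumed to be stateless. In real use, the accompanying tensor_spec would describe the structure of the raw argument.

```python
from typing import Any, Dict, NamedTuple, Tuple


class TimeStepStandIn(NamedTuple):
    """Stand-in with the same field names as a TimeStep (not the real class)."""
    step_type: Any
    reward: Any
    discount: Any
    observation: Any


def input_fn(raw: Dict[str, Any]) -> Tuple[TimeStepStandIn, Tuple]:
    """Convert a hypothetical raw input dict into (time_step, policy_state)."""
    time_step = TimeStepStandIn(
        step_type=raw["step_type"],
        reward=raw["reward"],
        discount=raw["discount"],
        observation=raw["features"],
    )
    # Assumption: the policy is stateless, so its state is an empty tuple.
    policy_state = ()
    return time_step, policy_state
```

With such a function in place, the action signature would accept the raw input alone, and the conversion to (time_step, policy_state) would happen inside the exported model.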

Raises:

  • TypeError: If policy is not an instance of TFPolicy.
  • ValueError: If use_nest_path_signatures is False and any of the following policy specs are missing names, or the names collide: policy.time_step_spec, policy.action_spec, policy.policy_state_spec, policy.info_spec.
  • ValueError: If batch_size is neither None nor a Python integer > 0.
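
The two ValueError conditions above can be sketched as plain Python checks. This is an illustration of the documented rules, not the library's implementation; check_batch_size and check_spec_names are hypothetical helper names, and spec names are passed in as a flat list of strings.

```python
from collections import Counter


def check_batch_size(batch_size):
    """Accept None or a positive Python int; raise ValueError otherwise."""
    if batch_size is None:
        return
    # bool is a subclass of int, so reject it explicitly.
    is_int = isinstance(batch_size, int) and not isinstance(batch_size, bool)
    if not is_int or batch_size <= 0:
        raise ValueError(
            "batch_size must be None or a Python integer > 0, got %r" % (batch_size,))


def check_spec_names(names):
    """Require every spec name to be present and unique."""
    if any(not name for name in names):
        raise ValueError("Some specs are missing names: %r" % (names,))
    duplicates = sorted(name for name, count in Counter(names).items() if count > 1)
    if duplicates:
        raise ValueError("Spec names collide: %r" % (duplicates,))
```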

Methods

get_train_step

get_train_step()

Returns the train step of the policy.

Returns:

An integer.

save

save(
    export_dir
)

Save the policy to the given export_dir.

save_checkpoint

save_checkpoint(
    export_dir
)

Saves the policy as a checkpoint to the given export_dir.

This will only work with checkpoints generated in TF2.x.

For the checkpoint to be useful, users should first call save to generate a saved_model of the policy. Checkpoints can then be used to update the policy without reloading the saved_model or saving multiple copies of the saved_model.pb file.

The checkpoint is always created in the sub-directory 'variables/' and the checkpoint file prefix used is 'variables'. The checkpoint files are as follows:

  • export_dir/variables/variables.index
  • export_dir/variables/variables.data-xxxxx-of-xxxxx

This makes the files compatible with the checkpoint part of full saved models, which enables you to load a saved model made up from the graph part of a full saved model and the variables part of a checkpoint.

Args:

  • export_dir: Directory to save the checkpoint to.
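
As a small illustration of the checkpoint layout described for save_checkpoint, the sketch below derives the checkpoint prefix and index file path from an export directory and checks whether a checkpoint appears to be present. The helper names are hypothetical and the code only manipulates paths; it does not read or validate the checkpoint itself.

```python
import os


def checkpoint_prefix(export_dir):
    """Return the checkpoint file prefix: export_dir/variables/variables."""
    return os.path.join(export_dir, "variables", "variables")


def checkpoint_index_path(export_dir):
    """Return the path of the checkpoint index file."""
    return checkpoint_prefix(export_dir) + ".index"


def has_checkpoint(export_dir):
    """Report whether the index file exists under export_dir."""
    return os.path.isfile(checkpoint_index_path(export_dir))
```

A caller could use has_checkpoint to skip export directories where save_checkpoint has not finished writing yet.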