tf_agents.policies.policy_saver.PolicySaver

A PolicySaver allows you to save a tf_policy.Policy to SavedModel.

The save() method exports a SavedModel to the requested export location. The exported SavedModel can be loaded via tf.compat.v2.saved_model.load (or tf.saved_model.load in TF2). It exposes the following signatures (concrete functions): action, get_initial_state, and get_train_step.

When the SavedModel is loaded, the attribute model_variables is also available; it gives access to the model variables so that they can be updated if needed.

Usage:


my_policy = agent.collect_policy
saver = PolicySaver(my_policy, batch_size=None)

for i in range(...):
  agent.train(...)
  if i % 100 == 0:
    saver.save('policy_%d' % i)

To load and use the saved policy directly:

saved_policy = tf.compat.v2.saved_model.load('policy_0')
policy_state = saved_policy.get_initial_state(batch_size=3)
time_step = ...
while True:
  policy_step = saved_policy.action(time_step, policy_state)
  policy_state = policy_step.state
  time_step = f(policy_step.action)
  ...

If using the flattened (signature) version, you will be limited to using dicts keyed by the specs' name fields.

saved_policy = tf.compat.v2.saved_model.load('policy_0')
get_initial_state_fn = saved_policy.signatures['get_initial_state']
action_fn = saved_policy.signatures['action']

policy_state_dict = get_initial_state_fn(batch_size=3)
time_step_dict = ...
while True:
  time_step_state = dict(time_step_dict)
  time_step_state.update(policy_state_dict)
  policy_step_dict = action_fn(time_step_state)
  policy_state_dict = extract_policy_state_fields(policy_step_dict)
  action_dict = extract_action_fields(policy_step_dict)
  time_step_dict = f(action_dict)
  ...
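
The helpers extract_policy_state_fields and extract_action_fields used in the loop above are not defined on this page. Below is a minimal sketch of how they might be written, assuming the signature output behaves like a flat dict and that you already know which keys belong to the policy state and which to the action. The key names "rnn_state" and "action" are hypothetical placeholders; the real keys come from the name fields of your policy's specs.

```python
# Hypothetical key sets. In practice, derive these from the names in
# policy.policy_state_spec and policy.action_spec.
POLICY_STATE_KEYS = frozenset({"rnn_state"})
ACTION_KEYS = frozenset({"action"})


def _select(flat_dict, keys):
    """Return the entries of flat_dict whose key is in keys."""
    return {k: v for k, v in flat_dict.items() if k in keys}


def extract_policy_state_fields(policy_step_dict):
    """Keep only the policy-state entries of a flat policy-step dict."""
    return _select(policy_step_dict, POLICY_STATE_KEYS)


def extract_action_fields(policy_step_dict):
    """Keep only the action entries of a flat policy-step dict."""
    return _select(policy_step_dict, ACTION_KEYS)
```

Filtering by key keeps the loop independent of any extra entries (for example info fields) that the action signature may also return.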

Args
policy A TF Policy.
batch_size The number of batch entries the policy will process at a time. This must be either None (unknown batch size) or a Python integer.
use_nest_path_signatures SavedModel spec signatures will be created based on the structure of the specs. Otherwise all specs must have unique names.
seed Random seed for the policy.action call, if any (this should usually be None, except for testing).
train_step Variable holding the train step for the policy. The value saved will be set at the time saver.save is called. If not provided, train_step defaults to -1.
input_fn_and_spec A (input_fn, tensor_spec) tuple where input_fn is a function that takes inputs according to tensor_spec and converts them to the (time_step, policy_state) tuple that is used as the input to the action_fn. When input_fn_and_spec is set, tensor_spec is the input for the action signature. When input_fn_and_spec is None, the action signature takes as input (time_step, policy_state).

Raises
TypeError If policy is not an instance of TFPolicy.
ValueError If use_nest_path_signatures is not set and any of the following policy specs are missing names, or the names collide: policy.time_step_spec, policy.action_spec, policy.policy_state_spec, policy.info_spec.
ValueError If batch_size is neither None nor a Python integer > 0.
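
To make the two ValueError conditions concrete, here is a rough pure-Python sketch of equivalent checks. It is not the library's implementation: check_batch_size and check_spec_names are hypothetical helpers, and the spec names are assumed to have been collected into a plain list of strings beforehand.

```python
def check_batch_size(batch_size):
    """Raise ValueError unless batch_size is None or a positive Python int."""
    if batch_size is None:
        return
    # bool is a subclass of int; rejecting it is a choice made by this sketch.
    is_int = isinstance(batch_size, int) and not isinstance(batch_size, bool)
    if not is_int or batch_size <= 0:
        raise ValueError(
            "batch_size must be None or a Python integer > 0, got %r"
            % (batch_size,))


def check_spec_names(names):
    """Raise ValueError if any name is missing (None or empty) or repeated."""
    seen = set()
    for name in names:
        if not name:
            raise ValueError("a spec is missing a name")
        if name in seen:
            raise ValueError("spec names collide: %r" % (name,))
        seen.add(name)
```

Running checks like these before constructing the saver can surface a bad argument earlier, with a message that points at the offending value.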

Methods

save

Save the policy to the given export_dir.

save_checkpoint

Saves the policy as a checkpoint to the given export_dir.

Args
export_dir Directory to save the checkpoint to.