PolicySaver allows you to save a tf_agents.policies.TFPolicy to a SavedModel.
```python
tf_agents.policies.policy_saver.PolicySaver(
    policy,
    batch_size=None,
    use_nest_path_signatures=True,
    seed=None,
    train_step=None,
    input_fn_and_spec=None
)
```
The save() method exports a SavedModel to the requested export location.
The SavedModel that is exported can be loaded via tf.compat.v2.saved_model.load (or tf.saved_model.load in TF2). It will have available signatures (concrete functions): action, get_initial_state, and get_train_step.
The attribute model_variables is also available when the saved_model is loaded, giving access to model variables in order to update them if needed.
Example:

```python
my_policy = agent.collect_policy
saver = PolicySaver(my_policy, batch_size=None)

for i in range(...):
  agent.train(...)
  if i % 100 == 0:
    saver.save('policy_%d' % global_step)
```
To load and use the saved policy directly:
```python
saved_policy = tf.compat.v2.saved_model.load('policy_0')
policy_state = saved_policy.get_initial_state(batch_size=3)
time_step = ...
while True:
  policy_step = saved_policy.action(time_step, policy_state)
  policy_state = policy_step.state
  time_step = f(policy_step.action)
  ...
```
If using the flattened (signature) version, you will be limited to using dicts keyed by the specs' name fields.
```python
saved_policy = tf.compat.v2.saved_model.load('policy_0')
get_initial_state_fn = saved_policy.signatures['get_initial_state']
action_fn = saved_policy.signatures['action']

policy_state_dict = get_initial_state_fn(batch_size=3)
time_step_dict = ...
while True:
  time_step_state = dict(time_step_dict)
  time_step_state.update(policy_state_dict)
  policy_step_dict = action_fn(time_step_state)
  policy_state_dict = extract_policy_state_fields(policy_step_dict)
  action_dict = extract_action_fields(policy_step_dict)
  time_step_dict = f(action_dict)
  ...
```
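The extract_policy_state_fields and extract_action_fields calls above are placeholders the user must supply. A minimal sketch of one way to write them, assuming (purely for illustration) that the flat output dict is keyed by spec names and that state fields share a common 'state/' name prefix:

```python
def extract_fields_by_prefix(flat_outputs, prefix):
  """Collect entries of a flat signature output dict whose keys start with
  the given prefix (hypothetical naming convention, not a tf_agents API)."""
  return {k: v for k, v in flat_outputs.items() if k.startswith(prefix)}

# Hypothetical flat output dict from action_fn:
outputs = {'action': 1, 'state/lstm_c': 0.0, 'state/lstm_h': 0.5}
policy_state_fields = extract_fields_by_prefix(outputs, 'state/')
action_fields = {k: v for k, v in outputs.items() if not k.startswith('state/')}
```

The actual key names depend on the name fields of your policy's specs, so inspect the loaded signature's structured_outputs before relying on any prefix convention.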
policy: A TF Policy.
batch_size: The number of batch entries the policy will process at a time. This must be either None (unknown batch size) or a python integer.
use_nest_path_signatures: SavedModel spec signatures will be created based on the structure of the specs. Otherwise all specs must have unique names.
seed: Random seed for the policy.action call, if any (this should usually be None, except for testing).
train_step: Variable holding the train step for the policy. The value saved will be set at the time saver.save is called. If not provided, train_step defaults to -1.
input_fn_and_spec: A (input_fn, tensor_spec) tuple where input_fn is a function that takes inputs according to tensor_spec and converts them to the (time_step, policy_state) tuple that is used as the input to the action_fn. When input_fn_and_spec is set, tensor_spec is the input for the action signature. When input_fn_and_spec is None, the action signature takes as input (time_step, policy_state).
TypeError: If policy is not an instance of TFPolicy.
ValueError: If use_nest_path_signatures is not used and any of the following policy specs are missing names, or the names collide: policy.time_step_spec, policy.action_spec, policy.policy_state_spec, policy.info_spec.
ValueError: If batch_size is not either None or a python integer > 0.
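The batch_size constraint above can be sketched as a standalone check (validate_batch_size is a hypothetical helper written for illustration, not part of the tf_agents API):

```python
def validate_batch_size(batch_size):
  """Check the documented constraint: None or a python integer > 0."""
  if batch_size is None:
    return batch_size
  # bool is a subclass of int in Python, so reject it explicitly.
  if not isinstance(batch_size, int) or isinstance(batch_size, bool) or batch_size <= 0:
    raise ValueError(
        'batch_size must be None or a python integer > 0, got %r' % (batch_size,))
  return batch_size
```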
get_train_step()

Returns the train step of the policy.
save( export_dir )
Saves the policy to the given export_dir.
save_checkpoint( export_dir )
Saves the policy as a checkpoint to the given export_dir.
This will only work with checkpoints generated in TF2.x.
For the checkpoint to be useful, users should first call save to generate a saved_model of the policy. Checkpoints can then be used to update the policy without having to reload the saved_model, or to save multiple copies of the graph.
The checkpoint is always created in the sub-directory 'variables/' and the checkpoint file prefix used is 'variables'. The checkpoint files are as follows:

- export_dir/variables/variables.index
- export_dir/variables/variables.data-XXXXX-of-XXXXX
This makes the files compatible with the checkpoint part of full saved models, which enables you to load a saved model made up from the graph part of a full saved model and the variables part of a checkpoint.
export_dir: Directory to save the checkpoint to.
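Given the 'variables/variables' prefix described above, the expected checkpoint file locations for an export_dir can be sketched with a small helper (checkpoint_paths is hypothetical, for illustration only; the number of data shards varies with how the variables are partitioned):

```python
import os

def checkpoint_paths(export_dir, num_shards=1):
  """Expected checkpoint file names under export_dir, assuming the
  'variables/variables' prefix used by save_checkpoint (sketch)."""
  prefix = os.path.join(export_dir, 'variables', 'variables')
  paths = [prefix + '.index']
  for shard in range(num_shards):
    # TF checkpoint data shards follow a 'data-XXXXX-of-XXXXX' pattern.
    paths.append('%s.data-%05d-of-%05d' % (prefix, shard, num_shards))
  return paths
```

Because this layout matches the variables/ sub-directory of a full SavedModel, the same paths apply whether the files came from save or from save_checkpoint.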