tf_agents.train.triggers.PolicySavedModelTrigger

Trigger that saves checkpoints of an agent's policies.

Inherits From: IntervalTrigger

On construction, this trigger generates a saved_model for each of: a greedy_policy, a collect_policy, and a raw_policy. When triggered, a checkpoint is saved that can be used to update any of the saved_model policies.

Args
saved_model_dir Base directory where checkpoints will be saved.
agent Agent to extract policies from.
train_step tf.Variable which keeps track of the number of train steps.
interval How often, in train_steps, the trigger will save. Note that the event is triggered as long as at least interval steps have passed since the last trigger. The current value is not necessarily exactly interval steps away from the last triggered value.
async_saving If True, saving is done asynchronously in a separate thread. Note that when this is enabled, the variable values in the saved checkpoints/models are not deterministic.
metadata_metrics A dictionary of metrics, whose result() method returns a scalar to be saved along with the policy. Currently only supported when async_saving is False.
start Initial value for the trigger, passed directly to the base class. It controls the train step from which the weights of the model are saved.
extra_concrete_functions Optional sequence of extra concrete functions to register in the policy savers. The sequence should consist of tuples with string name for the function and the tf.function to register. Note this does not support adding extra assets.
batch_size The number of batch entries the policy will process at a time. This must be either None (unknown batch size) or a python integer.
use_nest_path_signatures If True, SavedModel spec signatures will be created based on the structure of the specs. Otherwise, all specs must have unique names.
save_greedy_policy Disable when an agent's policy distribution method does not support mode.
save_collect_policy Disable to skip saving the collect policy.
input_fn_and_spec A (input_fn, tensor_spec) tuple where input_fn is a function that takes inputs according to tensor_spec and converts them to the (time_step, policy_state) tuple that is used as the input to the action_fn. When input_fn_and_spec is set, tensor_spec is the input for the action signature. When input_fn_and_spec is None, the action signature takes as input (time_step, policy_state).
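
The interval and start semantics described above can be illustrated with a small plain-Python sketch. It models only the documented rule (fire once at least interval steps have passed since the last triggered value) and is not the TF-Agents implementation; the helper name `fires_at` and its parameters are hypothetical and introduced here for illustration.

```python
def fires_at(values, interval, start=0):
    """Return the values at which an interval-style trigger would fire.

    Hypothetical helper: models only the rule "fire once at least
    `interval` steps have passed since the last triggered value".
    """
    last = start  # assumption: `start` seeds the last triggered value
    fired = []
    for value in values:
        if value - last >= interval:
            fired.append(value)
            last = value  # the next interval is measured from here
    return fired


# The train step is only inspected every 400 steps, with interval=1000.
checked_steps = list(range(0, 4001, 400))
fired_steps = fires_at(checked_steps, interval=1000)
print(fired_steps)
```

In this run the sketch fires at steps 1200, 2400, and 3600 rather than at exact multiples of 1000, because each interval is measured from the last triggered value and not from a fixed schedule.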

Methods

reset

Resets the trigger interval.

set_start

__call__

Maybe trigger the event based on the interval.

Args
value The value for triggering.
force_trigger If True, the trigger is forced to fire unless the last trigger value is equal to value.
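
As a rough illustration of how value and force_trigger interact, here is a minimal plain-Python sketch. `SketchTrigger` is a hypothetical stand-in written from the descriptions above, not the library class; in particular, passing the value to the callback and resetting to the initial start value are assumptions made for the example.

```python
class SketchTrigger:
    """Hypothetical stand-in for an interval trigger; not the library class."""

    def __init__(self, interval, fn, start=0):
        self._interval = interval
        self._fn = fn
        self._start = start
        self._last = start  # last value at which the trigger fired

    def __call__(self, value, force_trigger=False):
        due = value - self._last >= self._interval
        # A forced call is ignored when nothing changed since the last trigger.
        forced = force_trigger and value != self._last
        if due or forced:
            self._last = value
            self._fn(value)  # assumption: the sketch passes the value along

    def reset(self):
        # Assumption: resetting returns to the initial start value.
        self._last = self._start


saved = []
trigger = SketchTrigger(interval=100, fn=saved.append)
trigger(50)                       # not due: fewer than 100 steps since start
trigger(120)                      # due: fires at 120
trigger(150, force_trigger=True)  # forced: fires at 150 although not due
trigger(150, force_trigger=True)  # ignored: 150 equals the last trigger value
print(saved)
```

A typical use of a forced call is a final save at the end of training; the equality check keeps that call from writing a duplicate when the last step was already saved.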