Abstract base class for TF Policies.
tf_agents.policies.TFPolicy(
    time_step_spec: tf_agents.trajectories.TimeStep,
    action_spec: tf_agents.typing.types.NestedTensorSpec,
    policy_state_spec: tf_agents.typing.types.NestedTensorSpec = (),
    info_spec: tf_agents.typing.types.NestedTensorSpec = (),
    clip: bool = True,
    emit_log_probability: bool = False,
    automatic_state_reset: bool = True,
    observation_and_action_constraint_splitter: Optional[types.Splitter] = None,
    validate_args: bool = True,
    name: Optional[Text] = None
)
The Policy represents a mapping from time_steps received from the environment to actions that can be applied to the environment.

Agents expose two policies: a policy meant for deployment and evaluation, and a collect_policy for collecting data from the environment. The collect_policy is usually stochastic, so that it explores the environment better, and may also log auxiliary information such as the log probabilities required for training. Policy objects can also be created directly by users without using an Agent.
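As a hedged illustration (assuming `agent` is an already-constructed TF-Agents agent), the two policies are plain attributes:

eval_policy = agent.policy             # meant for deployment and evaluation
collect_policy = agent.collect_policy  # usually stochastic; used to gather training data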
The main methods of TFPolicy are:

* `action`: Maps a `time_step` from the environment to an action.
* `distribution`: Maps a `time_step` to a distribution over actions.
* `get_initial_state`: Generates the initial state for stateful policies, e.g. RNN/LSTM policies.
Example usage:
env = SomeTFEnvironment()
policy = RandomTFPolicy(env.time_step_spec(), env.action_spec())
# Or policy = agent.policy or agent.collect_policy

policy_state = policy.get_initial_state(env.batch_size)

time_step = env.reset()
while not time_step.is_last():
  policy_step = policy.action(time_step, policy_state)
  time_step = env.step(policy_step.action)
  policy_state = policy_step.state
  # policy_step.info may contain side info for logging, such as action log
  # probabilities.
Policies can be saved to disk as SavedModels (see policy_saver.py and policy_loader.py) or as TF Checkpoints.
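For example, a policy can be exported with `PolicySaver` and reloaded as a SavedModel. This is a hedged sketch that reuses `policy` from the example above; the export path is only illustrative:

import tensorflow as tf
from tf_agents.policies import policy_saver

saver = policy_saver.PolicySaver(policy)
saver.save('/tmp/my_policy')                            # illustrative export path

loaded_policy = tf.saved_model.load('/tmp/my_policy')   # exposes action(time_step, policy_state)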
A PyTFEagerPolicy can be used to wrap a TFPolicy so that it works with PyEnvironments.
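A minimal sketch of that wrapping, assuming an existing TFPolicy instance `tf_policy_instance` and a PyEnvironment instance `py_env`:

from tf_agents.policies import py_tf_eager_policy

eager_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy_instance)

time_step = py_env.reset()
action_step = eager_policy.action(time_step)  # works on numpy arrays rather than Tensors
time_step = py_env.step(action_step.action)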
For researchers, and those developing new Policies, the TFPolicy base class constructor also accepts a validate_args parameter. If False, this disables all spec structure, dtype, and shape checks in the public methods of these classes. It allows algorithm developers to iterate and try different input and output structures without worrying about overly restrictive requirements, or about input and output states being in a certain format. However, disabling argument validation can make it very hard to identify structural input or algorithmic errors, and it should not be done for final, or production-ready, Policies. In addition to having implementations that may disagree with specs, it means that the resulting Policy may no longer interact well with other parts of TF-Agents. Examples include impedance mismatches with Actor/Learner APIs, replay buffers, and the model export functionality in `PolicySaver`.
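As a hedged, minimal sketch of what implementing a new policy looks like, the illustrative subclass below (not part of TF-Agents) always emits the zero action; it assumes a single non-nested BoundedTensorSpec action and a single Tensor observation:

import tensorflow as tf
from tf_agents.policies import tf_policy
from tf_agents.trajectories import policy_step

class ZeroActionPolicy(tf_policy.TFPolicy):  # illustrative only
  """Always emits the zero action (assumes zero lies within the action bounds)."""

  def _action(self, time_step, policy_state, seed):
    batch_size = tf.shape(time_step.observation)[0]
    action_shape = tf.concat(
        [[batch_size],
         tf.constant(self.action_spec.shape.as_list(), dtype=tf.int32)],
        axis=0)
    action = tf.zeros(action_shape, dtype=self.action_spec.dtype)
    return policy_step.PolicyStep(action, policy_state, info=())

  def _distribution(self, time_step, policy_state):
    raise NotImplementedError('This sketch only implements _action.')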
Args | |
---|---|
`time_step_spec` | A `TimeStep` spec of the expected time_steps. Usually provided by the user to the subclass. |
`action_spec` | A nest of `BoundedTensorSpec` representing the actions. Usually provided by the user to the subclass. |
`policy_state_spec` | A nest of `TensorSpec` representing the policy_state. Provided by the subclass, not directly by the user. |
`info_spec` | A nest of `TensorSpec` representing the policy info. Provided by the subclass, not directly by the user. |
`clip` | Whether to clip actions to spec before returning them. Default True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped continuous actions for training. |
`emit_log_probability` | Emit log-probabilities of actions, if supported. If True, `policy_step.info` will have `CommonFields.LOG_PROBABILITY` set. Please consult the utility methods provided in `policy_step` for setting and retrieving these. When working with custom policies, either provide a dictionary `info_spec` or a namedtuple with the field 'log_probability'. |
`automatic_state_reset` | If True, then `get_initial_policy_state` is used to clear the state in `action()` and `distribution()` for time steps where `time_step.is_first()`. |
`observation_and_action_constraint_splitter` | A function used to process observations with action constraints. These constraints can indicate, for example, a mask of valid/invalid actions for a given state of the environment. The function takes in a full observation and returns a tuple consisting of 1) the part of the observation intended as input to the network and 2) the constraint. An example `observation_and_action_constraint_splitter` could be as simple as `def observation_and_action_constraint_splitter(observation): return observation['network_input'], observation['constraint']`. Note: when using `observation_and_action_constraint_splitter`, make sure the provided `q_network` is compatible with the network-specific half of the output of the `observation_and_action_constraint_splitter`. In particular, `observation_and_action_constraint_splitter` will be called on the observation before passing to the network. If `observation_and_action_constraint_splitter` is None, action constraints are not applied. |
`validate_args` | Python bool. Whether to verify inputs to, and outputs of, functions like `action` and `distribution` against spec structures, dtypes, and shapes. Research code may prefer to set this value to False; see the discussion above. |
`name` | A name for this module. Defaults to the class name. |
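To make the two most involved arguments concrete, this hedged sketch mirrors the splitter shown above (it assumes a dictionary observation with those two keys) and reads the log-probability emitted when emit_log_probability=True, using the `policy_step` helpers:

from tf_agents.trajectories import policy_step

def observation_and_action_constraint_splitter(observation):
  # Assumes the environment's observation is a dict with these two entries.
  return observation['network_input'], observation['constraint']

step = policy.action(time_step, policy_state)           # reuses names from the earlier example
log_prob = policy_step.get_log_probability(step.info)   # only present if emit_log_probability=True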
Attributes | |
---|---|
`action_spec` | Describes the TensorSpecs of the Tensors expected by `step(action)`. |
`collect_data_spec` | Describes the Tensors written when using this policy with an environment. |
`emit_log_probability` | Whether this policy instance emits log probabilities or not. |
`info_spec` | Describes the Tensors emitted as info by `action` and `distribution`. |
`observation_and_action_constraint_splitter` | |
`policy_state_spec` | Describes the Tensors expected by `step(_, policy_state)`. |
`policy_step_spec` | Describes the output of `action()`. |
`time_step_spec` | Describes the `TimeStep` tensors returned by `step()`. |
`trajectory_spec` | Describes the Tensors written when using this policy with an environment. |
`validate_args` | Whether `action` & `distribution` validate input and output args. |
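A hedged example of inspecting a few of these attributes (reusing `policy` from the earlier example):

print(policy.time_step_spec)   # structure of the TimeSteps the policy expects
print(policy.action_spec)      # dtype/shape/bounds of emitted actions
print(policy.trajectory_spec)  # what this policy writes when run against an environment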
Methods
action
action(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedTensor = (),
    seed: Optional[types.Seed] = None
) -> tf_agents.trajectories.PolicyStep
Generates next action given the time_step and policy_state.
Args | |
---|---|
`time_step` | A `TimeStep` tuple corresponding to `time_step_spec()`. |
`policy_state` | A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state. |
`seed` | Seed to use if action performs sampling (optional). |
Returns | |
---|---|
A `PolicyStep` named tuple containing: `action`, an action Tensor matching the `action_spec`; `state`, a policy state tensor to be fed into the next call to action; and `info`, optional side information such as action log probabilities. |
Raises | |
---|---|
`RuntimeError` | If subclass `__init__` didn't call `super().__init__`. |
`ValueError` or `TypeError` | If `validate_args` is True and inputs or outputs do not match `time_step_spec`, `policy_state_spec`, or `policy_step_spec`. |
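A hedged call sketch, reusing `policy`, `time_step`, and `policy_state` from the earlier example (the seed value is arbitrary):

policy_step_result = policy.action(time_step, policy_state, seed=1234)
action = policy_step_result.action        # matches action_spec
policy_state = policy_step_result.state   # feed into the next action() call
info = policy_step_result.info            # e.g. log probabilities, if emitted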
distribution
distribution(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedTensor = ()
) -> tf_agents.trajectories.PolicyStep
Generates the distribution over next actions given the time_step.
Args | |
---|---|
`time_step` | A `TimeStep` tuple corresponding to `time_step_spec()`. |
`policy_state` | A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state. |
Returns | |
---|---|
A `PolicyStep` named tuple containing: `action`, a distribution (or nest of distributions) over next actions; `state`, a policy state tensor for the next call to distribution; and `info`, optional side information such as action log probabilities. |
Raises | |
---|---|
`ValueError` or `TypeError` | If `validate_args` is True and inputs or outputs do not match `time_step_spec`, `policy_state_spec`, or `policy_step_spec`. |
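A hedged sketch of inspecting the distribution rather than sampling from it (assumes a stochastic policy whose action field is a single tfp.distributions.Distribution rather than a nest):

distribution_step = policy.distribution(time_step, policy_state)
action_distribution = distribution_step.action  # distribution over actions
greedy_action = action_distribution.mode()      # e.g. act greedily instead of sampling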
get_initial_state
get_initial_state(
    batch_size: Optional[types.Int]
) -> tf_agents.typing.types.NestedTensor
Returns an initial state usable by the policy.
Args | |
---|---|
`batch_size` | Tensor or constant: size of the batch dimension. Can be None, in which case no batch dimension is added. |
Returns | |
---|---|
A nested object of type `policy_state` containing properly initialized Tensors. |
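A hedged sketch, taking the batch size from the TF environment as in the earlier example:

policy_state = policy.get_initial_state(batch_size=env.batch_size)
# Stateful (e.g. RNN) policies return a nest of zero Tensors shaped [batch_size, ...];
# stateless policies return (), which action() and distribution() accept as-is.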
update
update(
    policy,
    tau: float = 1.0,
    tau_non_trainable: Optional[float] = None,
    sort_variables_by_name: bool = False
) -> tf.Operation
Update the current policy with another policy.
This would include copying the variables from the other policy.
Args | |
---|---|
`policy` | Another policy it can update from. |
`tau` | A float scalar in [0, 1]. When tau is 1.0 (the default), we do a hard update. This is used for trainable variables. |
`tau_non_trainable` | A float scalar in [0, 1] for non-trainable variables. If None, `tau` is used. |
`sort_variables_by_name` | A bool; when True, the variables are sorted by name before doing the update. |
Returns | |
---|---|
A TF op to do the update. |
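A hedged sketch of a soft (Polyak-style) update between two structurally compatible policies; the names `target_policy` and `online_policy` are illustrative:

update_op = target_policy.update(online_policy, tau=0.005)  # tau=1.0 (the default) copies variables exactly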