Exposes a numpy API for TF policies in Eager mode.

Inherits From: PyTFEagerPolicyBase, PyPolicy

tf_agents.policies.py_tf_eager_policy.PyTFEagerPolicy(
    policy: tf_agents.policies.tf_policy.TFPolicy,
    use_tf_function: bool = False
)
Args

time_step_spec
    A TimeStep ArraySpec of the expected time_steps. Usually provided by
    the user to the subclass.
action_spec
    A nest of BoundedArraySpec representing the actions. Usually provided
    by the user to the subclass.
policy_state_spec
    A nest of ArraySpec representing the policy state. Provided by the
    subclass, not directly by the user.
info_spec
    A nest of ArraySpec representing the policy info. Provided by the
    subclass, not directly by the user.
observation_and_action_constraint_splitter
    A function used to process observations with action constraints.
    These constraints can indicate, for example, a mask of valid/invalid
    actions for a given state of the environment. The function takes in a
    full observation and returns a tuple consisting of 1) the part of the
    observation intended as input to the network and 2) the constraint.
    An example observation_and_action_constraint_splitter could be as
    simple as:

        def observation_and_action_constraint_splitter(observation):
            return observation['network_input'], observation['constraint']

    Note: when using observation_and_action_constraint_splitter, make
    sure the provided q_network is compatible with the network-specific
    half of the output of the observation_and_action_constraint_splitter.
    In particular, observation_and_action_constraint_splitter will be
    called on the observation before passing to the network. If
    observation_and_action_constraint_splitter is None, action
    constraints are not applied.
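The splitter described above can be sketched concretely as follows; the dict keys and the mask contents are illustrative:

```python
import numpy as np

def observation_and_action_constraint_splitter(observation):
    # Split a dict observation into (network input, action constraint).
    # 'network_input' and 'constraint' are illustrative key names.
    return observation['network_input'], observation['constraint']

observation = {
    'network_input': np.zeros([4], dtype=np.float32),
    # Mask marking actions 0 and 2 as valid and action 1 as invalid.
    'constraint': np.array([1, 0, 1], dtype=np.int32),
}
net_input, mask = observation_and_action_constraint_splitter(observation)
```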
Attributes

action_spec
    Describes the ArraySpecs of the np.Array returned by action().
    action can be a single np.Array, or a nested dict, list or tuple of
    np.Array.
collect_data_spec
    Describes the data collected when using this policy with an
    environment.
info_spec
    Describes the Arrays emitted as info by action().
observation_and_action_constraint_splitter
policy_state_spec
    Describes the arrays expected by functions with policy_state as
    input.
policy_step_spec
    Describes the output of action().
time_step_spec
    Describes the TimeStep np.Arrays expected by action(time_step).
trajectory_spec
    Describes the data collected when using this policy with an
    environment.
Methods

action

action(
    time_step: tf_agents.trajectories.time_step.TimeStep,
    policy_state: tf_agents.typing.types.NestedArray = ()
) -> tf_agents.trajectories.policy_step.PolicyStep
Generates next action given the time_step and policy_state.
Args

time_step
    A TimeStep tuple corresponding to time_step_spec().
policy_state
    An optional previous policy_state.

Returns

A PolicyStep named tuple containing:
    action: A nest of action Arrays matching the action_spec().
    state: A nest of policy states to be fed into the next call to
        action.
    info: Optional side information such as action log probabilities.
get_initial_state

get_initial_state(
    batch_size: Optional[int] = None
) -> tf_agents.typing.types.NestedArray

Returns an initial state usable by the policy.

Args

batch_size
    An optional batch size.

Returns

An initial policy state.
variables

variables()

Returns the variables of the wrapped policy.