Exposes a numpy API for saved_model policies in Eager mode.

Inherits From: PyTFEagerPolicyBase, PyPolicy
    tf_agents.policies.py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        model_path: Text,
        time_step_spec: Optional[tf_agents.trajectories.time_step.TimeStep] = None,
        action_spec: Optional[types.NestedTensorSpec] = None,
        policy_state_spec: tf_agents.typing.types.NestedTensorSpec = (),
        info_spec: tf_agents.typing.types.NestedTensorSpec = (),
        load_specs_from_pbtxt: bool = False
    )
| Args | |
|---|---|
| `model_path` | Path to a saved_model generated by the `policy_saver`. |
| `time_step_spec` | Optional nested structure of ArraySpecs describing the policy's `time_step_spec`. This is not used by the SavedModelPyTFEagerPolicy, but may be accessed by other objects as it is part of the public policy API. |
| `action_spec` | Optional nested structure of ArraySpecs describing the policy's `action_spec`. This is not used by the SavedModelPyTFEagerPolicy, but may be accessed by other objects as it is part of the public policy API. |
| `policy_state_spec` | Optional nested structure of ArraySpecs describing the policy's `policy_state_spec`. This is not used by the SavedModelPyTFEagerPolicy, but may be accessed by other objects as it is part of the public policy API. |
| `info_spec` | Optional nested structure of ArraySpecs describing the policy's `info_spec`. This is not used by the SavedModelPyTFEagerPolicy, but may be accessed by other objects as it is part of the public policy API. |
| `load_specs_from_pbtxt` | If True, the specs are loaded from the proto file generated by the `policy_saver`. |
| Attributes | |
|---|---|
| `action_spec` | Describes the ArraySpecs of the np.Array returned by `action()`. |
| `collect_data_spec` | Describes the data collected when using this policy with an environment. |
| `info_spec` | Describes the Arrays emitted as info by `action()`. |
| `observation_and_action_constraint_splitter` | |
| `policy_state_spec` | Describes the arrays expected by functions with `policy_state` as input. |
| `policy_step_spec` | Describes the output of `action()`. |
| `time_step_spec` | Describes the TimeStep np.Arrays expected by `action(time_step)`. |
| `trajectory_spec` | Describes the data collected when using this policy with an environment. |
Methods

action

    action(
        time_step: tf_agents.trajectories.time_step.TimeStep,
        policy_state: tf_agents.typing.types.NestedArray = ()
    ) -> tf_agents.trajectories.policy_step.PolicyStep

Generates the next action given the time_step and policy_state.
| Args | |
|---|---|
| `time_step` | A TimeStep tuple corresponding to `time_step_spec()`. |
| `policy_state` | An optional previous policy_state. |

Returns

A PolicyStep named tuple containing:

- `action`: A nest of action Arrays matching the `action_spec()`.
- `state`: A nest of policy states to be fed into the next call to action.
- `info`: Optional side information such as action log probabilities.
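The `action()` contract above can be sketched with a minimal stand-in policy in plain Python and NumPy. `DummyPolicy`, its argmax rule, and the 3-element observation are invented for illustration; only the named-tuple shapes mirror the real tf_agents API.

```python
import collections
import numpy as np

# Minimal stand-ins for the tf_agents named tuples.
TimeStep = collections.namedtuple(
    'TimeStep', ['step_type', 'reward', 'discount', 'observation'])
PolicyStep = collections.namedtuple('PolicyStep', ['action', 'state', 'info'])


class DummyPolicy:
    """Illustrates the action(time_step, policy_state) -> PolicyStep shape."""

    def get_initial_state(self, batch_size=None):
        # A stateless policy returns an empty nest, matching policy_state_spec=().
        return ()

    def action(self, time_step, policy_state=()):
        # Toy rule: pick the index of the largest observation entry.
        act = np.argmax(time_step.observation).astype(np.int32)
        return PolicyStep(action=act, state=policy_state, info=())


policy = DummyPolicy()
ts = TimeStep(step_type=np.int32(0), reward=np.float32(0.0),
              discount=np.float32(1.0),
              observation=np.array([0.1, 0.9, 0.3], dtype=np.float32))
step = policy.action(ts, policy.get_initial_state())
print(step.action)  # index of the largest observation entry -> 1
```

The same three-field `PolicyStep` structure is what the real saved_model policy returns, with `action` matching `action_spec()`.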
get_initial_state

    get_initial_state(
        batch_size: Optional[int] = None
    ) -> tf_agents.typing.types.NestedArray

Returns an initial state usable by the policy.

| Args | |
|---|---|
| `batch_size` | An optional batch size. |

Returns

An initial policy state.
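For a stateful policy (e.g. one wrapping an RNN), the state returned by each `action()` call is fed into the next one, starting from `get_initial_state()`. A sketch of that loop with a hypothetical policy whose state is a running sum of observations (`RunningSumPolicy` is invented for illustration):

```python
import collections
import numpy as np

TimeStep = collections.namedtuple(
    'TimeStep', ['step_type', 'reward', 'discount', 'observation'])
PolicyStep = collections.namedtuple('PolicyStep', ['action', 'state', 'info'])


class RunningSumPolicy:
    """Hypothetical stateful policy: its state accumulates observations."""

    def get_initial_state(self, batch_size=None):
        return np.zeros(3, dtype=np.float32)

    def action(self, time_step, policy_state):
        new_state = policy_state + time_step.observation
        return PolicyStep(action=np.argmax(new_state).astype(np.int32),
                          state=new_state, info=())


policy = RunningSumPolicy()
state = policy.get_initial_state()      # start from the initial state
for obs in ([1., 0., 0.], [0., 3., 0.]):
    ts = TimeStep(0, 0.0, 1.0, np.array(obs, dtype=np.float32))
    step = policy.action(ts, state)
    state = step.state                  # thread the state into the next call
print(step.action)  # argmax of the accumulated observations -> 1
```

Threading `step.state` back in is the caller's responsibility; dropping it silently resets a stateful policy on every step.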
get_metadata

    get_metadata()

Returns the metadata of the saved model.

get_train_step

    get_train_step() -> tf_agents.typing.types.Int

Returns the training global step of the saved model.
update_from_checkpoint

    update_from_checkpoint(
        checkpoint_path: Text
    )

Allows users to update saved_model variables directly from a checkpoint.

`checkpoint_path` is a path that was passed to either `PolicySaver.save()` or `PolicySaver.save_checkpoint()`. The policy looks for a set of checkpoint files with the file prefix `

| Args | |
|---|---|
| `checkpoint_path` | Path to the checkpoint to restore and use to update this policy. |
variables

    variables()