tf_agents.environments.TrajectoryReplay
A helper that replays a policy against given Trajectory
observations.
tf_agents.environments.TrajectoryReplay(
policy, time_major=False
)
Args
policy | A tf_policy.TFPolicy policy.
time_major | If True, the tensors in the trajectory passed to method run are assumed to have shape [time, batch, ...]. Otherwise (default) they are assumed to have shape [batch, time, ...].
Raises
ValueError | If policy is not an instance of tf_policy.TFPolicy.
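The two layouts selected by time_major differ only in the order of the two outer dimensions. The sketch below illustrates this with plain Python nested lists rather than tensors; swap_outer_dims and the observation labels are hypothetical names used for illustration and are not part of TF-Agents.

```python
def swap_outer_dims(nested):
    """Swap the two outer dimensions of a rectangular nested list."""
    return [list(row) for row in zip(*nested)]


# Hypothetical recorded observations: a batch of 2 sequences, 3 steps each.
# Batch-major layout, i.e. [batch, time].
batch_major = [
    ["b0_t0", "b0_t1", "b0_t2"],
    ["b1_t0", "b1_t1", "b1_t2"],
]

# Time-major layout, i.e. [time, batch]: one row per time step.
time_major = swap_outer_dims(batch_major)
```

Here time_major[t][b] holds the same element as batch_major[b][t], which is the relationship between the [time, batch, ...] and [batch, time, ...] shapes described above.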
Methods
run
run(
trajectory, policy_state=None
)
Apply the policy to trajectory steps and store actions/info.
If self.time_major == True, the tensors in trajectory are assumed to have shape [time, batch, ...]. Otherwise they are assumed to have shape [batch, time, ...].
Args
trajectory | The Trajectory to run against. If the replay class was created with time_major=True, then the tensors in trajectory must be shaped [time, batch, ...]. Otherwise they must be shaped [batch, time, ...].
policy_state | (Optional) A nested Tensor structure holding the policy state for the initial step.
Returns
output_actions | A nest of the actions that the policy took. If the replay class was created with time_major=True, then the tensors here will be shaped [time, batch, ...]. Otherwise they will be shaped [batch, time, ...].
output_policy_info | A nest of the policy info that the policy emitted. If the replay class was created with time_major=True, then the tensors here will be shaped [time, batch, ...]. Otherwise they will be shaped [batch, time, ...].
policy_state | A nested Tensor structure holding the policy state after the final step.
Raises
TypeError | If policy_state structure doesn't match self.policy.policy_state_spec, or trajectory structure doesn't match self.policy.trajectory_spec.
ValueError | If policy_state doesn't match self.policy.policy_state_spec, or trajectory structure doesn't match self.policy.trajectory_spec.
ValueError | If trajectory lacks two outer dims.
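The contract of run (apply the policy to each recorded step, collect the actions and policy info, and return the final policy state) can be sketched conceptually in plain Python. This is not the TF-Agents implementation: replay, counting_policy, and the policy_fn signature are hypothetical, and the sketch ignores batching, tensor specs, and any handling of episode boundaries the real class may perform.

```python
def replay(policy_fn, observations, policy_state=None):
    """Apply policy_fn to each recorded observation in time order.

    policy_fn is a hypothetical callable:
        policy_fn(observation, state) -> (action, info, next_state)
    observations is a list with one entry per time step.
    """
    output_actions = []
    output_policy_info = []
    state = policy_state
    for observation in observations:
        # The state returned at one step is fed into the next step.
        action, info, state = policy_fn(observation, state)
        output_actions.append(action)
        output_policy_info.append(info)
    return output_actions, output_policy_info, state


def counting_policy(observation, state):
    """Toy stateful policy: doubles the observation and counts its calls."""
    count = 0 if state is None else state
    return observation * 2, {"step": count}, count + 1


actions, infos, final_state = replay(counting_policy, [1, 2, 3])
```

The three return values mirror output_actions, output_policy_info, and policy_state: one action and one info entry per recorded step, plus the state left over after the last step.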
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-04-26 UTC.