tf_agents.trajectories.to_n_step_transition
Create an n-step transition from a trajectory with T = N + 1 frames.
tf_agents.trajectories.to_n_step_transition(
    trajectory: tf_agents.trajectories.Trajectory,
    gamma: tf_agents.typing.types.Float
) -> tf_agents.trajectories.Transition
Note: Tensors of trajectory are sliced along their second (time) dimension to pull out the appropriate fields for the n-step transitions.
The output transition's next_time_step.{reward, discount} will contain the N-step discounted reward and discount values, calculated as:
    next_time_step.reward = r_t +
        g^{1} * d_t * r_{t+1} +
        g^{2} * d_t * d_{t+1} * r_{t+2} +
        g^{3} * d_t * d_{t+1} * d_{t+2} * r_{t+3} +
        ... +
        g^{N-1} * d_t * ... * d_{t+N-2} * r_{t+N-1}
    next_time_step.discount = g^{N-1} * d_t * d_{t+1} * ... * d_{t+N-1}
In Python notation:
    discount = gamma ** (N - 1) * reduce_prod(trajectory.discount[:, :-1])
    reward = discounted_return(
        rewards=trajectory.reward[:, :-1],
        discounts=gamma * trajectory.discount[:, :-1])
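The formulas above can be checked by hand on small inputs. The following is a minimal sketch in plain Python, not the library's implementation: `n_step_reward_and_discount` is a hypothetical helper name, and nested lists of shape [B, T] stand in for the reward and discount tensors.

```python
from math import prod


def n_step_reward_and_discount(rewards, discounts, gamma):
    """Apply the n-step formulas to [B, T] nested lists (hypothetical helper)."""
    out_rewards, out_discounts = [], []
    for r_row, d_row in zip(rewards, discounts):
        # Drop the final frame, mirroring trajectory.reward[:, :-1].
        r, d = r_row[:-1], d_row[:-1]
        n = len(r)  # N = T - 1
        # Backward recursion: G_k = r_k + gamma * d_k * G_{k+1}.
        ret = 0.0
        for r_k, d_k in zip(reversed(r), reversed(d)):
            ret = r_k + gamma * d_k * ret
        out_rewards.append(ret)
        # g^{N-1} times the product of the N retained discounts.
        out_discounts.append(gamma ** (n - 1) * prod(d))
    return out_rewards, out_discounts


# Two sequences with T = 4 frames (N = 3); the second has d_{t+1} = 0.
rewards = [[1.0, 2.0, 3.0, 0.0], [1.0, 2.0, 3.0, 0.0]]
discounts = [[1.0, 1.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0]]
n_step_r, n_step_d = n_step_reward_and_discount(rewards, discounts, gamma=0.5)
print(n_step_r)  # [2.75, 2.0]
print(n_step_d)  # [0.25, 0.0]
```

In the second sequence the zero discount at t+1 removes the r_{t+2} term from the reward and zeroes the final discount.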
When trajectory.discount[:, :-1] is an all-ones tensor, this is equivalent to:
    next_time_step.discount = (
        gamma ** (N - 1) * tf.ones_like(trajectory.discount[:, 0]))
    next_time_step.reward = (
        sum_{n=0}^{N-1} gamma ** n * trajectory.reward[:, n])
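As a quick sanity check of the all-ones special case, the sketch below compares the general weighted sum with the plain geometric sum for a single sequence. Both function names are hypothetical and the inputs are ordinary Python lists, so this illustrates the arithmetic only, not the library code.

```python
def general_n_step_reward(rewards, discounts, gamma):
    """r_0 + sum over n >= 1 of gamma^n * d_0 * ... * d_{n-1} * r_n."""
    total, weight = 0.0, 1.0
    for r_n, d_n in zip(rewards, discounts):
        total += weight * r_n
        weight *= gamma * d_n
    return total


def all_ones_n_step_reward(rewards, gamma):
    """sum_{n=0}^{N-1} gamma^n * r_n, valid when every discount is 1."""
    return sum(gamma ** n * r_n for n, r_n in enumerate(rewards))


# One sequence with T = 4 frames (N = 3); the last frame is dropped first.
reward = [1.0, 2.0, 4.0, 0.0][:-1]
discount = [1.0, 1.0, 1.0, 1.0][:-1]
gamma = 0.5

general = general_n_step_reward(reward, discount, gamma)
special = all_ones_n_step_reward(reward, gamma)
print(general, special)  # 3.0 3.0
```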
Args
trajectory: An instance of Trajectory. The tensors in Trajectory must have shape [B, T, ...]. discount is assumed to be a scalar float, hence the shape of trajectory.discount must be [B, T].
gamma: A floating point scalar; the discount factor.
Returns
An N-step Transition where N = T - 1. The reward and discount in time_step.{reward, discount} are NaN. The n-step discounted reward and final discount are stored in next_time_step.{reward, discount}.
All tensors in the Transition have shape [B, ...] (no time dimension).
Raises
ValueError: if discount.shape.rank != 2.
ValueError: if discount.shape[1] < 2.
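The two conditions above amount to requiring a [B, T] discount with at least two frames. A rough plain-Python analogue is sketched below; `check_discount_shape` is a hypothetical helper operating on nested lists, and its error messages are illustrative rather than the library's.

```python
def check_discount_shape(discount):
    """Raise ValueError unless `discount` is a [B, T] nested list with T >= 2."""
    is_rank_2 = (
        isinstance(discount, list)
        and all(isinstance(row, list) for row in discount)
        and all(not isinstance(x, list) for row in discount for x in row)
    )
    if not is_rank_2:
        # Analogue of discount.shape.rank != 2.
        raise ValueError("discount must have rank 2, i.e. shape [B, T]")
    if any(len(row) < 2 for row in discount):
        # Analogue of discount.shape[1] < 2.
        raise ValueError("discount needs at least 2 frames along the time axis")


check_discount_shape([[1.0, 1.0, 1.0]])  # shape [1, 3]: accepted

for bad in ([1.0, 1.0], [[1.0]], [[[1.0, 1.0]]]):
    try:
        check_discount_shape(bad)
    except ValueError as err:
        print("rejected:", err)
```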
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License , and code samples are licensed under the Apache 2.0 License . For details, see the Google Developers Site Policies . Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-04-26 UTC.