tf_agents.agents.data_converter.AsTransition

Class that validates and converts other data types to Transition.

Note that validation and conversion allow values to contain dictionaries with extra keys compared to the specs in the data context. These additional entries / observations are ignored and dropped during conversion.

This non-strict checking allows users to provide additional info and observation keys at input without having to manually prune them before converting.
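The non-strict pruning described above can be pictured with plain dictionaries. The following is only a minimal sketch of the idea, not the TF-Agents implementation: `prune_to_spec`, the string placeholders used as spec leaves, and the example keys are all hypothetical names chosen for illustration.

```python
def prune_to_spec(value, spec):
    """Keep only the entries of `value` that `spec` names, recursively.

    Hypothetical helper: extra keys in `value` are silently dropped,
    while keys required by `spec` but absent from `value` are an error.
    """
    if not isinstance(spec, dict):
        return value  # Leaf of the spec: keep the value as-is.
    if not isinstance(value, dict):
        raise TypeError("expected a dict where the spec has a dict")
    missing = set(spec) - set(value)
    if missing:
        raise ValueError(f"value is missing spec keys: {sorted(missing)}")
    return {key: prune_to_spec(value[key], spec[key]) for key in spec}


# Spec leaves are placeholder strings here; real specs are tensor specs.
spec = {"observation": {"position": "float32[2]"}, "reward": "float32[]"}
value = {
    "observation": {"position": [0.0, 1.0], "debug_image": "raw pixels"},
    "reward": 1.0,
    "extra_info": {"episode_id": 7},
}

pruned = prune_to_spec(value, spec)
print(pruned)
```

Here `debug_image` and `extra_info` are dropped because the spec does not mention them, so the caller does not have to remove them by hand first.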

Args
data_context An instance of DataContext, typically accessed from the TFAgent.data_context property.
squeeze_time_dim Whether to emit a transition without time dimensions. If True, incoming trajectories are expected to have a time dimension of exactly 2, and emitted Transitions will have no time dimensions.
prepend_t0_to_next_time_step Whether to prepend t0 to next_time_step. This option is useful when using a sequential model, since it lets the target network take in more information. The resulting shape of time_step.observation is [B, T, ...] and the resulting shape of next_time_step.observation is [B, T+1, ...].
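The effect of the two flags on observation shapes can be worked through with simple arithmetic. The sketch below is illustrative only: `transition_observation_shapes` is a hypothetical helper, and it assumes that T in the prepend_t0_to_next_time_step description refers to the length of time_step after conversion (one less than the input trajectory length). The combination of both flags is not described above, so the sketch does not model it.

```python
def transition_observation_shapes(batch, time, inner,
                                  squeeze_time_dim=False,
                                  prepend_t0=False):
    """Return (time_step_shape, next_time_step_shape) as tuples.

    `batch` and `time` are the outer dims of the input trajectory and
    `inner` is a tuple with the remaining observation dims.
    """
    if squeeze_time_dim and prepend_t0:
        raise NotImplementedError("combination not modeled in this sketch")
    if squeeze_time_dim:
        if time != 2:
            raise ValueError("squeeze_time_dim requires a time dimension of 2")
        shape = (batch, *inner)  # Time dimension removed entirely.
        return shape, shape
    steps = time - 1  # A trajectory of length T yields T - 1 transitions.
    next_steps = steps + 1 if prepend_t0 else steps
    return (batch, steps, *inner), (batch, next_steps, *inner)


print(transition_observation_shapes(4, 2, (3,), squeeze_time_dim=True))
print(transition_observation_shapes(4, 5, (3,)))
print(transition_observation_shapes(4, 5, (3,), prepend_t0=True))
```

With prepend_t0 set, next_time_step is one step longer than time_step, which matches the [B, T, ...] and [B, T+1, ...] shapes given above.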

Methods

__call__

Converts value to a Transition. Performs data validation and pruning.

  • If value is already a Transition, only validation is performed.
  • If value is a Trajectory and squeeze_time_dim = True, then value must have tensors with shape [B, T=2] outer dims. It is converted to a Transition object without a time dimension.
  • If squeeze_time_dim = True and value is a Trajectory whose tensors have a time dimension with T != 2, a ValueError is raised.
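The Trajectory cases can be illustrated on plain nested lists standing in for [B, T] observation tensors. This is a rough sketch under stated assumptions, not the library code: `split_observations` is a hypothetical name, only the observation field is modeled, and the T != 2 error is raised only when squeeze_time_dim is set, following the Raises section.

```python
def split_observations(observations, squeeze_time_dim=False):
    """Split [B, T] nested-list observations into (current, next).

    Each inner list is one batch entry's sequence of observations.
    """
    current, following = [], []
    for sequence in observations:
        if squeeze_time_dim:
            if len(sequence) != 2:
                raise ValueError("expected a time dimension of exactly 2")
            # Drop the time axis: one observation before, one after.
            current.append(sequence[0])
            following.append(sequence[1])
        else:
            # Keep the time axis: T steps become T - 1 transitions.
            current.append(sequence[:-1])
            following.append(sequence[1:])
    return current, following


pairs = split_observations([[10, 11], [20, 21]], squeeze_time_dim=True)
windows = split_observations([[10, 11, 12]])
print(pairs)    # ([10, 20], [11, 21])
print(windows)  # ([[10, 11]], [[11, 12]])
```

In the squeezed case each batch entry contributes a single (observation, next observation) pair; otherwise the two outputs are the same sequence shifted by one step.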

Args
value A Trajectory or Transition object to convert.

Returns
A validated and pruned Transition. If squeeze_time_dim = True, the resulting Transition has tensors with shape [B, ...]. Otherwise, the tensors will have shape [B, T - 1, ...].

Raises
TypeError If value is not one of Trajectory or Transition.
ValueError If value has a structure that doesn't match the converter's spec.
TypeError If value has a structure that doesn't match the converter's spec.
ValueError If squeeze_time_dim=True and value is a Trajectory with a time dimension having value other than T=2.