Returns the global and per-arm context dimensions.
tf_agents.specs.bandit_spec_utils.get_context_dims_from_spec(
    context_spec: tf_agents.typing.types.NestedTensorSpec,
    accepts_per_arm_features: bool
) -> Tuple[int, int]
If the policy accepts per-arm features, this function returns a tuple of the global and per-arm context dimensions. Otherwise, it returns the global context dimension and zero.
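The branching described above can be sketched in plain Python. This is a hedged illustration of the documented behavior, not the tf_agents implementation. `SpecStub`, `context_dims`, and the dict keys `"global"` and `"per_arm"` are hypothetical stand-ins. The sketch assumes that a per-arm context spec is a dict holding a rank-1 global spec of shape `[global_dim]` and a rank-2 per-arm spec of shape `[num_arms, per_arm_dim]`, and that a context spec without per-arm features is a single rank-1 spec.

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union


# Hypothetical stand-in for a tensor spec: only the static shape matters here.
@dataclass(frozen=True)
class SpecStub:
    shape: Tuple[int, ...]


# Assumed key names for the per-arm observation dict (hypothetical).
GLOBAL_KEY = "global"
PER_ARM_KEY = "per_arm"

ContextSpec = Union[SpecStub, Dict[str, SpecStub]]


def context_dims(context_spec: ContextSpec,
                 accepts_per_arm_features: bool) -> Tuple[int, int]:
    """Sketch of the documented (global_dim, per_arm_dim) behavior."""
    if accepts_per_arm_features:
        # Global features are assumed to have shape [global_dim].
        global_dim = context_spec[GLOBAL_KEY].shape[0]
        # Per-arm features are assumed to have shape [num_arms, per_arm_dim];
        # the per-arm dimension is the feature axis, not the number of arms.
        per_arm_dim = context_spec[PER_ARM_KEY].shape[1]
        return global_dim, per_arm_dim
    # Without per-arm features, the context is assumed to be one rank-1 spec.
    if len(context_spec.shape) != 1:
        raise ValueError("expected a rank-1 context spec")
    return context_spec.shape[0], 0


per_arm_spec = {GLOBAL_KEY: SpecStub((10,)), PER_ARM_KEY: SpecStub((5, 4))}
print(context_dims(per_arm_spec, True))     # (10, 4)
print(context_dims(SpecStub((7,)), False))  # (7, 0)
```

Reading the second axis of the per-arm spec reflects the assumption that the leading axis indexes arms, so the returned per-arm dimension counts features per arm rather than arms.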
| Args | |
|---|---|
| `context_spec` | A nest of tensor specs containing the observation spec. |
| `accepts_per_arm_features` | (bool) Whether `context_spec` is for a policy that accepts per-arm features. |
Returns: A 2-tuple of ints: the global and per-arm context dimensions. If the policy does not accept per-arm features, the per-arm context dimension is 0.