tf_agents.policies.py_tf_policy.PyTFPolicy

Exposes a Python policy as a wrapper over a TF Policy.

Inherits From: Base, SessionUser

tf_agents.policies.py_tf_policy.PyTFPolicy(
    policy, batch_size=None, seed=None
)

Args:

  • policy: A TF Policy implementing tf_policy.Base.
  • batch_size: (deprecated)
  • seed: Seed to use if the policy performs random actions (optional).
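
The wrapper relationship can be illustrated with a minimal pure-Python sketch. It is not the TF-Agents implementation, which evaluates the wrapped policy's graph in a TensorFlow session; `FakeTFPolicy` and `PyWrapperSketch` are hypothetical names, and the sketch only shows delegation plus how a fixed seed makes random actions reproducible.

```python
import random
from typing import Any, Callable, Optional


class FakeTFPolicy:
    """Hypothetical stand-in for a TF policy: maps an observation to an action."""

    def __init__(self, fn: Callable[[Any, random.Random], Any]) -> None:
        self._fn = fn

    def action(self, observation: Any, rng: random.Random) -> Any:
        # Any randomness comes from the caller-supplied generator.
        return self._fn(observation, rng)


class PyWrapperSketch:
    """Hypothetical Python-side wrapper that delegates action() to a policy.

    `seed` plays the role of the constructor's seed argument: it only
    matters when the wrapped policy samples random actions.
    """

    def __init__(self, policy: FakeTFPolicy, seed: Optional[int] = None) -> None:
        self._policy = policy
        self._rng = random.Random(seed)

    def action(self, observation: Any) -> Any:
        return self._policy.action(observation, self._rng)


# A stochastic policy that picks one of three discrete actions uniformly.
random_policy = FakeTFPolicy(lambda obs, rng: rng.randrange(3))
a = PyWrapperSketch(random_policy, seed=42)
b = PyWrapperSketch(random_policy, seed=42)
print([a.action(None) for _ in range(5)] == [b.action(None) for _ in range(5)])  # True
```

Keeping the random generator on the wrapper, rather than in the wrapped policy, is what lets two wrappers built with the same seed produce identical action sequences in this sketch.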

Attributes:

  • action_spec: Describes the ArraySpecs of the np.Array returned by action().

    The action can be a single np.Array or a nested dict, list, or tuple of np.Arrays.

  • info_spec: Describes the Arrays emitted as info by action().

  • observation_and_action_constraint_splitter

  • policy_state_spec: Describes the arrays expected by functions with policy_state as input.

  • policy_step_spec: Describes the output of action().

  • session: Returns the TensorFlow session-like object used by this object.

  • time_step_spec: Describes the TimeStep np.Arrays expected by action(time_step).

  • trajectory_spec: Describes the data collected when using this policy with an environment.
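
To make the nested action structure concrete, here is a hypothetical helper that flattens an action nest into its leaf values. It assumes dict entries are visited in sorted-key order, as tf.nest does, and uses plain Python numbers in place of np.Array leaves; `flatten_nest` is an illustrative name, not part of TF-Agents.

```python
from typing import Any, List


def flatten_nest(nest: Any) -> List[Any]:
    """Flatten a nest of dicts, lists and tuples into a list of leaves.

    Dict values are visited in sorted-key order; lists and tuples are
    visited in order; anything else is treated as a single leaf.
    """
    if isinstance(nest, dict):
        return [leaf for key in sorted(nest) for leaf in flatten_nest(nest[key])]
    if isinstance(nest, (list, tuple)):
        return [leaf for item in nest for leaf in flatten_nest(item)]
    return [nest]


# Hypothetical action nest: a steering value plus a dict of discrete buttons.
action = (0.25, {"jump": 1, "fire": 0})
print(flatten_nest(action))  # [0.25, 0, 1]
```

A fixed traversal order matters because a flat list of leaves can only be matched back to the spec if both are flattened the same way.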

Methods

action

action(
    time_step, policy_state=()
)

Generates next action given the time_step and policy_state.

Args:

  • time_step: A TimeStep tuple corresponding to time_step_spec().
  • policy_state: An optional previous policy_state.

Returns:

A PolicyStep named tuple containing:

  • action: A nest of action Arrays matching the action_spec().
  • state: A nest of policy states to be fed into the next call to action.
  • info: Optional side information such as action log probabilities.
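
The sketch below shows how the returned state is meant to be fed into the next call to action(). The PolicyStep named tuple is redefined locally with the three fields listed above, and `counting_action` is a hypothetical recurrent policy whose state is simply a step counter; a real PyTFPolicy computes these values from its TF graph.

```python
import collections
from typing import Any

# Local stand-in with the same three fields as the documented PolicyStep.
PolicyStep = collections.namedtuple("PolicyStep", ["action", "state", "info"])


def counting_action(time_step: Any, policy_state: Any = ()) -> PolicyStep:
    """Hypothetical recurrent policy whose state counts the steps taken.

    The default empty tuple is treated as "no previous state".
    """
    count = policy_state if policy_state != () else 0
    action = count % 2  # alternate between actions 0 and 1
    return PolicyStep(action=action, state=count + 1, info=())


# Thread each returned state into the next call.
state = ()
actions = []
for time_step in range(4):  # hypothetical time steps
    step = counting_action(time_step, state)
    actions.append(step.action)
    state = step.state
print(actions, state)  # [0, 1, 0, 1] 4
```

Forgetting to pass the previous state back in would reset this policy on every call, so it would always return action 0.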

get_initial_state

get_initial_state(
    batch_size=None
)

Returns an initial state usable by the policy.

Args:

  • batch_size: An optional batch size.

Returns:

An initial policy state.
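
As a rough illustration of the optional batch_size argument, the following sketch assumes a recurrent policy whose initial state is all zeros (`STATE_SIZE` and `initial_state` are hypothetical) and adds a leading batch dimension only when a batch size is given.

```python
from typing import List, Optional, Union

STATE_SIZE = 3  # hypothetical size of a recurrent policy's hidden state


def initial_state(
    batch_size: Optional[int] = None,
) -> Union[List[float], List[List[float]]]:
    """Return a zero-filled state, batched only if batch_size is given."""
    zeros = [0.0] * STATE_SIZE
    if batch_size is None:
        return zeros
    # One independent copy per batch entry.
    return [list(zeros) for _ in range(batch_size)]


print(initial_state())                   # [0.0, 0.0, 0.0]
print(len(initial_state(batch_size=2)))  # 2
```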

initialize

initialize(
    batch_size, graph=None
)

restore

restore(
    policy_dir, graph=None, assert_consumed=True
)

Restores the policy from the checkpoint.

Args:

  • policy_dir: Directory with the checkpoint.
  • graph: A graph inside which the policy is restored (optional).
  • assert_consumed: If True, the contents of the checkpoint are checked for a match against graph variables.

Returns:

  • step: Global step associated with the restored policy checkpoint.

Raises:

  • RuntimeError: if the policy is not initialized.
  • AssertionError: if the checkpoint contains variables which do not have matching names in the graph, and assert_consumed is set to True.
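
The checking rules above can be sketched in pure Python as follows. `restore_sketch` is hypothetical: it models only the RuntimeError and assert_consumed behavior by comparing variable names, whereas the real method loads a TensorFlow checkpoint and returns the global step.

```python
from typing import Dict, Iterable


def restore_sketch(checkpoint: Dict[str, float],
                   graph_variables: Iterable[str],
                   initialized: bool = True,
                   assert_consumed: bool = True) -> Dict[str, float]:
    """Hypothetical sketch of the name-matching rules documented for restore().

    Returns the checkpoint values whose names match a graph variable.
    """
    if not initialized:
        raise RuntimeError("The policy must be initialized before restoring.")
    graph_names = set(graph_variables)
    # Checkpoint entries with no same-named variable in the graph.
    unmatched = sorted(set(checkpoint) - graph_names)
    if assert_consumed and unmatched:
        raise AssertionError(f"Unmatched checkpoint variables: {unmatched}")
    return {name: value for name, value in checkpoint.items() if name in graph_names}


# Hypothetical checkpoint with one variable the graph does not define.
ckpt = {"dense/kernel": 0.5, "dense/bias": 0.1, "old/extra": 9.0}
graph = ["dense/kernel", "dense/bias"]

try:
    restore_sketch(ckpt, graph)
except AssertionError as err:
    print(err)  # Unmatched checkpoint variables: ['old/extra']

print(restore_sketch(ckpt, graph, assert_consumed=False))
# {'dense/kernel': 0.5, 'dense/bias': 0.1}
```

Setting assert_consumed to False trades safety for flexibility: stale or renamed variables are silently skipped instead of failing the restore.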

save

save(
    policy_dir=None, graph=None
)