

In this colab we will cover how to define custom networks for your agents. Networks define the model that an agent trains. TF-Agents provides several types of networks that are useful across agents:

Main Networks

  • QNetwork: Used in Q-learning for environments with discrete actions, this network maps an observation to value estimates for each possible action.
  • CriticNetworks: Also referred to as ValueNetworks in the literature, these learn to estimate some version of a value function, mapping a state to an estimate of the expected return of a policy. In other words, they estimate how good the agent's current state is.
  • ActorNetworks: Learn a mapping from observations to actions. These networks are usually used by our policies to generate actions.
  • ActorDistributionNetworks: Similar to ActorNetworks, but these generate a distribution that a policy can then sample to generate actions.
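To make the difference between these four roles concrete, here is a toy pure-Python sketch of what each kind of network returns for a single observation. The functions are hypothetical stand-ins with hand-picked arithmetic, not TF-Agents classes; only the output types matter: per-action values, a scalar value, a deterministic action, and distribution parameters to sample from.

```python
import math
import random

# Toy stand-ins for the four main network roles. These are NOT TF-Agents
# classes; they only illustrate what each kind of network outputs for a
# single observation.

def toy_q_network(obs, num_actions=3):
    # One value estimate per discrete action.
    return [sum(obs) * (a + 1) for a in range(num_actions)]

def toy_critic_network(obs):
    # A single scalar estimate of how good the current state is.
    return sum(obs) / len(obs)

def toy_actor_network(obs, low=-1.0, high=1.0):
    # A deterministic continuous action, squashed into [low, high].
    squashed = math.tanh(sum(obs))
    return low + (squashed + 1.0) * 0.5 * (high - low)

def toy_actor_distribution_network(obs, std=0.1):
    # Parameters (mean, std) of a Normal distribution; the policy samples it.
    return math.tanh(sum(obs)), std

obs = [0.5, -0.2, 0.1]
q_values = toy_q_network(obs)
greedy_action = max(range(len(q_values)), key=q_values.__getitem__)
value = toy_critic_network(obs)
action = toy_actor_network(obs)
mean, std = toy_actor_distribution_network(obs)
rng = random.Random(0)
sampled_action = rng.gauss(mean, std)
```

A Q-learning policy would act greedily on `q_values`, while a policy built on an actor distribution network draws `sampled_action` from the returned parameters.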

Helper Networks

  • EncodingNetwork: Allows users to easily define a mapping of pre-processing layers to apply to a network's input.
  • DynamicUnrollLayer: Automatically resets the network's state on episode boundaries as it is applied over a time sequence.
  • ProjectionNetwork: Networks like CategoricalProjectionNetwork or NormalProjectionNetwork take inputs and generate the parameters required to build Categorical or Normal distributions.
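The state-reset behavior described for DynamicUnrollLayer can be sketched in plain Python. This is a minimal illustration, not the TF-Agents implementation: the step-type constants mirror the values of TF-Agents' `StepType` (FIRST=0, MID=1, LAST=2), and `running_sum_cell` is a hypothetical recurrent cell whose state is just a running sum.

```python
# Constants mirroring TF-Agents' StepType values; FIRST marks an episode start.
FIRST, MID, LAST = 0, 1, 2

def dynamic_unroll(cell, inputs, step_types, initial_state):
    """Apply `cell` over a time sequence, resetting state at episode starts."""
    state = initial_state
    outputs = []
    for x, step_type in zip(inputs, step_types):
        if step_type == FIRST:
            state = initial_state  # new episode: forget the previous one
        out, state = cell(x, state)
        outputs.append(out)
    return outputs, state

def running_sum_cell(x, state):
    # Hypothetical recurrent cell: the state is the running sum of inputs.
    new_state = state + x
    return new_state, new_state

outputs, final_state = dynamic_unroll(
    running_sum_cell,
    inputs=[1, 2, 3, 10, 20],
    step_types=[FIRST, MID, LAST, FIRST, MID],
    initial_state=0)
print(outputs)  # [1, 3, 6, 10, 30]
```

Without the reset, the fourth output would be 16 instead of 10, because the second episode would inherit the first episode's state.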

All examples in TF-Agents come with pre-configured networks. However, these networks are not set up to handle complex observations.

If you have an environment that exposes more than one observation/action and you need to customize your networks, then this tutorial is for you!


If you haven't installed tf-agents yet, run:

pip install tf-agents
pip install tf-keras

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
# This must be set before TensorFlow is imported.
os.environ['TF_USE_LEGACY_KERAS'] = '1'

import abc
import tensorflow as tf
import numpy as np

from tf_agents.environments import random_py_environment
from tf_agents.environments import tf_py_environment
from tf_agents.networks import encoding_network
from tf_agents.networks import network
from tf_agents.networks import utils
from tf_agents.specs import array_spec
from tf_agents.utils import common as common_utils
from tf_agents.utils import nest_utils

Defining Networks

Network API

In TF-Agents, networks subclass Keras Networks. This lets us:

  • Simplify copy operations required when creating target networks.
  • Perform automatic variable creation when calling network.variables().
  • Validate inputs based on network input_specs.
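The first point, simpler copies for target networks, works because a network can be rebuilt from the arguments it was constructed with. TF-Agents' `Network.copy` is documented as recreating the network from its original constructor arguments, without copying weights. The sketch below imitates that idea in plain Python; `ToyNetwork`, `_saved_kwargs`, and the random "weights" are hypothetical and are not the library's code.

```python
import random

class ToyNetwork:
    """Hypothetical stand-in for a network that remembers its constructor args."""

    def __init__(self, input_size, output_size, name="ToyNetwork"):
        # Save the constructor arguments so copy() can rebuild the network.
        self._saved_kwargs = dict(
            input_size=input_size, output_size=output_size, name=name)
        self.variables = None  # created lazily, like Keras variables

    def create_variables(self):
        # Build the (toy) weight matrix on first use only.
        if self.variables is None:
            kw = self._saved_kwargs
            self.variables = [
                [random.uniform(-1.0, 1.0) for _ in range(kw["input_size"])]
                for _ in range(kw["output_size"])]
        return self.variables

    def copy(self, **overrides):
        # Same architecture, fresh (not yet created) variables.
        kwargs = dict(self._saved_kwargs, **overrides)
        return type(self)(**kwargs)

q_net = ToyNetwork(input_size=4, output_size=2, name="QNetwork")
q_net.create_variables()
target_q_net = q_net.copy(name="TargetQNetwork")
```

The copy shares the architecture but not the weights, which is what a target network needs; an agent would then sync weights explicitly when it chooses to.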


As mentioned above, the EncodingNetwork allows us to easily define a mapping of preprocessing layers to apply to a network's input to generate some encoding.

The EncodingNetwork is composed of the following mostly optional layers:

  • Preprocessing layers
  • Preprocessing combiner
  • Conv2D
  • Flatten
  • Dense
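The order of these stages can be sketched as a plain-Python pipeline in which every stage is optional. This is a toy illustration under simplifying assumptions (observations and preprocessing layers are dicts with matching keys, layers are ordinary callables, and `concat_sorted` is a hypothetical combiner); it is not EncodingNetwork's actual code.

```python
def encode(observation, preprocessing_layers=None, preprocessing_combiner=None,
           conv_layers=(), flatten=None, dense_layers=()):
    """Toy encoding pipeline; every stage is optional and applied in order."""
    # Preprocessing layers: one callable per observation key.
    if preprocessing_layers is not None:
        observation = {key: preprocessing_layers[key](value)
                       for key, value in observation.items()}
    # Preprocessing combiner: merge the preprocessed inputs into one value.
    if preprocessing_combiner is not None:
        observation = preprocessing_combiner(observation)
    # Conv2D stack, then Flatten, then Dense stack.
    for conv in conv_layers:
        observation = conv(observation)
    if flatten is not None:
        observation = flatten(observation)
    for dense in dense_layers:
        observation = dense(observation)
    return observation

def concat_sorted(features):
    # Hypothetical combiner: concatenate list features in sorted key order.
    return [x for key in sorted(features) for x in features[key]]

encoded = encode(
    {"a": [1, 2], "b": [3]},
    preprocessing_layers={"a": lambda xs: [2 * x for x in xs],
                          "b": lambda xs: [-x for x in xs]},
    preprocessing_combiner=concat_sorted,
    dense_layers=[lambda xs: [sum(xs)]])
print(encoded)  # [3]
```

Here "a" becomes [2, 4], "b" becomes [-3], the combiner joins them into [2, 4, -3], and the single dense stage sums them.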

What sets encoding networks apart is that they apply input preprocessing. Input preprocessing is possible via the preprocessing_layers and preprocessing_combiner layers. Each of these can be specified as a nested structure. If the preprocessing_layers nest is shallower than input_tensor_spec, then each layer receives the corresponding subnest. For example, if:

input_tensor_spec = ([TensorSpec(3)] * 2, [TensorSpec(3)] * 5)
preprocessing_layers = (Layer1(), Layer2())

then preprocessing will call:

preprocessed = [preprocessing_layers[0](observations[0]),
                preprocessing_layers[1](observations[1])]
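The "shallower nest" rule can be sketched with a small recursive helper. This is a pure-Python illustration, not TF-Agents code: `apply_shallow` and the two layer functions are hypothetical, and strings stand in for the tensors matching each TensorSpec(3) leaf.

```python
def apply_shallow(layers, observations):
    """Apply each leaf layer to the matching subnest of `observations`.

    `layers` may be shallower than `observations`: when a layer is reached,
    it receives the whole subnest found at that position.
    """
    if isinstance(layers, (list, tuple)):
        if (not isinstance(observations, (list, tuple))
                or len(layers) != len(observations)):
            raise ValueError("layers do not match the observation structure")
        return [apply_shallow(layer, obs)
                for layer, obs in zip(layers, observations)]
    return layers(observations)

def layer1(subnest):
    # Hypothetical layer that reports the subnest it received.
    return "Layer1(%s)" % ", ".join(subnest)

def layer2(subnest):
    return "Layer2(%s)" % ", ".join(subnest)

# Strings stand in for the tensors of ([TensorSpec(3)] * 2, [TensorSpec(3)] * 5).
observations = (["x0", "x1"], ["y0", "y1", "y2", "y3", "y4"])
preprocessed = apply_shallow((layer1, layer2), observations)
print(preprocessed)  # ['Layer1(x0, x1)', 'Layer2(y0, y1, y2, y3, y4)']
```

Layer1 sees both tensors of the first group at once, and Layer2 sees all five tensors of the second group.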

However, if

preprocessing_layers = ([Layer1() for _ in range(2)],
                        [Layer2() for _ in range(5)])

then preprocessing will call:

preprocessed = [
  layer(obs) for layer, obs in zip(flatten(preprocessing_layers),
                                    flatten(observations))
]
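When the layer nest matches the observation nest leaf for leaf, flattening both and zipping them is enough. The sketch below is a pure-Python illustration with a hypothetical `flatten` helper and `make_layer` factory; strings again stand in for tensors.

```python
def flatten(nest):
    """Flatten nested lists/tuples into a flat list of leaves, left to right."""
    if isinstance(nest, (list, tuple)):
        return [leaf for item in nest for leaf in flatten(item)]
    return [nest]

def make_layer(name):
    # Hypothetical layer that records which observation it was applied to.
    return lambda obs: "%s(%s)" % (name, obs)

observations = (["x0", "x1"], ["y0", "y1", "y2", "y3", "y4"])
preprocessing_layers = ([make_layer("L1") for _ in range(2)],
                        [make_layer("L2") for _ in range(5)])

flat_layers = flatten(preprocessing_layers)
flat_obs = flatten(observations)
if len(flat_layers) != len(flat_obs):
    raise ValueError("structures must match leaf for leaf")
preprocessed = [layer(obs) for layer, obs in zip(flat_layers, flat_obs)]
print(preprocessed)
# ['L1(x0)', 'L1(x1)', 'L2(y0)', 'L2(y1)', 'L2(y2)', 'L2(y3)', 'L2(y4)']
```

Each of the seven layers now sees exactly one tensor, in contrast to the shallower case where each layer saw a whole group.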

Custom Networks

To create your own networks, you only have to override the __init__ and call methods. Let's use what we learned about EncodingNetworks to create an ActorNetwork that takes observations containing an image and a vector.

class ActorNetwork(network.Network):

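  def __init__(self,
               observation_spec,
               action_spec,
               preprocessing_layers=None,
               preprocessing_combiner=None,
               conv_layer_params=None,
               fc_layer_params=(75, 40),
               dropout_layer_params=None,
               activation_fn=tf.keras.activations.relu,
               name='ActorNetwork'):
    super(ActorNetwork, self).__init__(
        input_tensor_spec=observation_spec, state_spec=(), name=name)

    # For simplicity we will only support a single action float output.
    self._action_spec = action_spec
    flat_action_spec = tf.nest.flatten(action_spec)
    if len(flat_action_spec) > 1:
      raise ValueError('Only a single action is supported by this network')
    self._single_action_spec = flat_action_spec[0]
    if self._single_action_spec.dtype not in [tf.float32, tf.float64]:
      raise ValueError('Only float actions are supported by this network.')

    kernel_initializer = tf.keras.initializers.VarianceScaling(
        scale=1. / 3., mode='fan_in', distribution='uniform')
    # The encoder applies the preprocessing layers and combiner, then the
    # optional conv and fully connected layers. call() does its own batch
    # squashing, so the encoder does not need to.
    self._encoder = encoding_network.EncodingNetwork(
        observation_spec,
        preprocessing_layers=preprocessing_layers,
        preprocessing_combiner=preprocessing_combiner,
        conv_layer_params=conv_layer_params,
        fc_layer_params=fc_layer_params,
        dropout_layer_params=dropout_layer_params,
        activation_fn=activation_fn,
        kernel_initializer=kernel_initializer,
        batch_squash=False)

    initializer = tf.keras.initializers.RandomUniform(
        minval=-0.003, maxval=0.003)

    # Project the encoding to the action size; tanh keeps outputs in [-1, 1]
    # so call() can rescale them to the action spec bounds.
    self._action_projection_layer = tf.keras.layers.Dense(
        flat_action_spec[0].shape.num_elements(),
        activation=tf.keras.activations.tanh,
        kernel_initializer=initializer,
        name='action')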
  def call(self, observations, step_type=(), network_state=()):
    outer_rank = nest_utils.get_outer_rank(observations, self.input_tensor_spec)
    # We use batch_squash here in case the observations have a time sequence
    # component.
    batch_squash = utils.BatchSquash(outer_rank)
    observations = tf.nest.map_structure(batch_squash.flatten, observations)

    state, network_state = self._encoder(
        observations, step_type=step_type, network_state=network_state)
    actions = self._action_projection_layer(state)
    actions = common_utils.scale_to_spec(actions, self._single_action_spec)
    actions = batch_squash.unflatten(actions)
    return tf.nest.pack_sequence_as(self._action_spec, [actions]), network_state
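Two steps in call() are easy to check by hand: batch squashing of the outer dimensions and rescaling of the tanh output to the action bounds. The sketch below reproduces both in plain Python under stated assumptions: it works on shape tuples rather than tensors, and `scale_to_spec` here is a hypothetical re-implementation of the linear mapping from [-1, 1] onto [minimum, maximum], written to illustrate what `common_utils.scale_to_spec` is used for, not the library function itself.

```python
def scale_to_spec(values, minimum, maximum):
    """Map values in [-1, 1] (e.g. tanh outputs) linearly onto [minimum, maximum]."""
    means = (maximum + minimum) / 2.0
    magnitudes = (maximum - minimum) / 2.0
    return [means + magnitudes * v for v in values]

def squash_outer_dims(shape, outer_rank):
    """Collapse the first `outer_rank` dims (e.g. [B, T]) into one batch dim."""
    batch = 1
    for d in shape[:outer_rank]:
        batch *= d
    return (batch,) + tuple(shape[outer_rank:])

def unsquash_outer_dims(shape, outer_shape):
    """Restore the original outer dims after the network has run."""
    return tuple(outer_shape) + tuple(shape[1:])

# A [batch=4, time=7] sequence of 16x16x3 images becomes a batch of 28 images.
squashed = squash_outer_dims((4, 7, 16, 16, 3), outer_rank=2)
# The projection layer outputs one 3-dim action per squashed element.
restored = unsquash_outer_dims((28, 3), outer_shape=(4, 7))
# tanh outputs at the extremes and center map to the spec's bounds and midpoint.
scaled = scale_to_spec([-1.0, 0.0, 1.0], minimum=0.0, maximum=10.0)
print(squashed, restored, scaled)  # (28, 16, 16, 3) (4, 7, 3) [0.0, 5.0, 10.0]
```

Squashing lets the encoder and projection layers treat time steps as extra batch entries, so the same network works for single steps and for sequences.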

Let's create a RandomPyEnvironment to generate structured observations and validate our implementation.

action_spec = array_spec.BoundedArraySpec((3,), np.float32, minimum=0, maximum=10)
observation_spec = {
    'image': array_spec.BoundedArraySpec((16, 16, 3), np.float32, minimum=0,
                                        maximum=255),
    'vector': array_spec.BoundedArraySpec((5,), np.float32, minimum=-100,
                                          maximum=100)}

random_env = random_py_environment.RandomPyEnvironment(observation_spec, action_spec=action_spec)

# Convert the environment to a TFEnv to generate tensors.
tf_env = tf_py_environment.TFPyEnvironment(random_env)
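For intuition about what the random environment hands to the network, here is a plain-Python sketch that samples one structured observation uniformly within per-key bounds. It is not how RandomPyEnvironment is implemented; the `(shape, minimum, maximum)` tuples loosely mirror the observation spec above, and the upper bounds (255 and 100) are assumptions chosen for illustration.

```python
import random

# Hypothetical plain-Python spec: (shape, minimum, maximum) per observation key.
# The upper bounds are illustrative assumptions.
observation_spec = {
    "image": ((16, 16, 3), 0.0, 255.0),
    "vector": ((5,), -100.0, 100.0),
}

def sample_uniform(shape, low, high, rng):
    """Build a nested list of the given shape filled with uniform samples."""
    if len(shape) == 1:
        return [rng.uniform(low, high) for _ in range(shape[0])]
    return [sample_uniform(shape[1:], low, high, rng) for _ in range(shape[0])]

def sample_observation(spec, rng):
    """Sample one dict observation with the same keys as `spec`."""
    return {key: sample_uniform(shape, lo, hi, rng)
            for key, (shape, lo, hi) in spec.items()}

rng = random.Random(0)
obs = sample_observation(observation_spec, rng)
```

The result is a dict with the same keys as the spec, which is why the network below needs one preprocessing layer per key.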

Since we've defined the observations to be a dict, we need to create preprocessing layers to handle them.

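preprocessing_layers = {
    'image': tf.keras.models.Sequential([tf.keras.layers.Conv2D(8, 4),
                                        tf.keras.layers.Flatten()]),
    'vector': tf.keras.layers.Dense(5)
    }
preprocessing_combiner = tf.keras.layers.Concatenate(axis=-1)
actor = ActorNetwork(tf_env.observation_spec(),
                     tf_env.action_spec(),
                     preprocessing_layers=preprocessing_layers,
                     preprocessing_combiner=preprocessing_combiner)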
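It can help to work out the encoding size by hand. The arithmetic below assumes Keras' Conv2D defaults ('valid' padding, stride 1), assumes the image branch ends with a Flatten layer, and uses the default fc_layer_params=(75, 40) from the ActorNetwork constructor; it is a back-of-the-envelope check, not output from TF-Agents.

```python
def conv2d_valid_output(size, kernel, stride=1):
    """Spatial output size of a 'valid' (unpadded) convolution."""
    return (size - kernel) // stride + 1

# Image branch: Conv2D(8, 4) on a 16x16x3 input, then (assumed) Flatten.
height = width = conv2d_valid_output(16, 4)            # 13
image_features = height * width * 8                    # 13 * 13 * 8 = 1352
# Vector branch: Dense(5) produces 5 features.
vector_features = 5
# Concatenate(axis=-1) joins the two feature vectors.
combined_features = image_features + vector_features   # 1357
# EncodingNetwork's fc layers (75, 40), then the 3-unit action projection.
layer_widths = [combined_features, 75, 40, 3]
print(layer_widths)  # [1357, 75, 40, 3]
```

If you change the image size or the conv parameters, the width of the combined encoding changes accordingly, but the rest of the network adapts automatically because Keras infers input sizes when the layers are first called.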

Now that we have the actor network we can process observations from the environment.

time_step = tf_env.reset()
actor(time_step.observation, time_step.step_type)
(<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[5.8357787, 4.3249702, 3.6428978]], dtype=float32)>,
 ())

This same strategy can be used to customize any of the main networks used by the agents. You can define whatever preprocessing you need and connect it to the rest of the network. As you define your own custom networks, make sure the network's output layer matches what the agent expects (for example, the shape and bounds of the action spec).