Copyright 2018 The TF-Agents Authors.
Reinforcement learning algorithms use replay buffers to store trajectories of experience when executing a policy in an environment. During training, replay buffers are queried for a subset of the trajectories (either a sequential subset or a sample) to "replay" the agent's experience.
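Before looking at the TF-Agents API, it may help to see the two read modes mentioned above (a sequential subset versus a random sample) in isolation. The following is a minimal pure-Python sketch; `ToyReplayBuffer` and its `add`, `sequential` and `sample` methods are hypothetical names, not part of TF-Agents.

```python
import collections
import random


class ToyReplayBuffer:
    """Hypothetical FIFO replay buffer; not a TF-Agents class."""

    def __init__(self, capacity):
        # A deque with maxlen silently evicts the oldest item when full.
        self._storage = collections.deque(maxlen=capacity)

    def add(self, item):
        self._storage.append(item)

    def sequential(self, start, length):
        # Contiguous slice of experience, oldest first.
        return list(self._storage)[start:start + length]

    def sample(self, batch_size, rng=random):
        # Uniform sample without replacement.
        return rng.sample(list(self._storage), batch_size)


buffer = ToyReplayBuffer(capacity=5)
for step in range(8):
    buffer.add(("obs", step))  # steps 0, 1 and 2 end up evicted

print(buffer.sequential(0, 3))  # [('obs', 3), ('obs', 4), ('obs', 5)]
print(len(buffer.sample(2, random.Random(0))))  # 2
```

The fixed capacity with eviction of the oldest items is the behaviour to keep in mind for the buffers below.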
In this colab, we explore two types of replay buffers, python-backed and tensorflow-backed, that share a common API. In the following sections, we describe the API, each of the buffer implementations, and how to use them during data collection and training.
Install tf-agents if you haven't already.
pip install -q tf-agents
pip install -q gym
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

from tf_agents import specs
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import q_network
from tf_agents.replay_buffers import py_uniform_replay_buffer
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import time_step

tf.compat.v1.enable_v2_behavior()
Replay Buffer API
The Replay Buffer class has the following definition and methods:
class ReplayBuffer(tf.Module):
  """Abstract base class for TF-Agents replay buffer."""

  def __init__(self, data_spec, capacity):
    """Initializes the replay buffer.

    Args:
      data_spec: A spec or a list/tuple/nest of specs describing
        a single item that can be stored in this buffer
      capacity: number of elements that the replay buffer can hold.
    """

  @property
  def data_spec(self):
    """Returns the spec for items in the replay buffer."""

  @property
  def capacity(self):
    """Returns the capacity of the replay buffer."""

  def add_batch(self, items):
    """Adds a batch of items to the replay buffer."""

  def get_next(self,
               sample_batch_size=None,
               num_steps=None,
               time_stacked=True):
    """Returns an item or batch of items from the buffer."""

  def as_dataset(self,
                 sample_batch_size=None,
                 num_steps=None,
                 num_parallel_calls=None):
    """Creates and returns a dataset that returns entries from the buffer."""

  def gather_all(self):
    """Returns all the items in buffer."""
    return self._gather_all()

  def clear(self):
    """Resets the contents of replay buffer"""
Note that when the replay buffer object is initialized, it requires the data_spec of the elements that it will store. This spec corresponds to the TensorSpec of trajectory elements that will be added to the buffer. This spec is usually acquired by looking at an agent's agent.collect_data_spec, which defines the shapes, types, and structures expected by the agent when training (more on that later).
TFUniformReplayBuffer is the most commonly used replay buffer in TF-Agents, so we will use it in this tutorial. In TFUniformReplayBuffer, the backing storage is done by TensorFlow variables and is therefore part of the compute graph.
The buffer stores batches of elements and has a maximum capacity of max_length elements per batch segment. Thus, the total buffer capacity is batch_size × max_length elements. The elements stored in the buffer must all have a matching data spec. When the replay buffer is used for data collection, the spec is the agent's collect data spec.
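To make the capacity arithmetic concrete, here is a pure-Python sketch that models the buffer as batch_size independent ring buffers of max_length slots each. `SegmentedRingBuffer` and `num_stored` are hypothetical names chosen for illustration; this is an assumption about the storage layout for teaching purposes, not the TF-Agents implementation.

```python
class SegmentedRingBuffer:
    """Hypothetical sketch: batch_size ring buffers of max_length slots each."""

    def __init__(self, batch_size, max_length):
        self.batch_size = batch_size
        self.max_length = max_length
        # One list of max_length slots per batch segment.
        self.segments = [[None] * max_length for _ in range(batch_size)]
        self._num_adds = 0  # number of add_batch calls so far

    @property
    def capacity(self):
        # Total number of elements the buffer can hold at once.
        return self.batch_size * self.max_length

    def add_batch(self, items):
        if len(items) != self.batch_size:
            raise ValueError("outer dimension must equal batch_size")
        # All segments write to the same slot, overwriting the oldest entry.
        slot = self._num_adds % self.max_length
        for segment, item in zip(self.segments, items):
            segment[slot] = item
        self._num_adds += 1

    def num_stored(self):
        # Elements currently held, summed over all segments.
        return min(self._num_adds, self.max_length) * self.batch_size


print(SegmentedRingBuffer(batch_size=32, max_length=1000).capacity)  # 32000

small = SegmentedRingBuffer(batch_size=2, max_length=3)
for t in range(4):
    small.add_batch([t, t])  # the fourth call overwrites slot 0
print(small.segments[0], small.num_stored())  # [3, 1, 2] 6
```

Once a segment is full, each new add_batch call replaces that segment's oldest element, so the total never exceeds batch_size × max_length.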
Creating the buffer:
To create a TFUniformReplayBuffer we pass in:
- data_spec: the spec of the data elements that the buffer will store
- batch_size: the batch size of the buffer
- max_length: the number of elements per batch segment
Here is an example of creating a TFUniformReplayBuffer with sample data specs, batch_size 32 and max_length 1000:
data_spec = (
    tf.TensorSpec([3], tf.float32, 'action'),
    (
        tf.TensorSpec([5], tf.float32, 'lidar'),
        tf.TensorSpec([3, 2], tf.float32, 'camera')
    )
)

batch_size = 32
max_length = 1000

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec,
    batch_size=batch_size,
    max_length=max_length)
Writing to the buffer:
To add elements to the replay buffer, we use the add_batch(items) method, where items is a list/tuple/nest of tensors representing the batch of items to be added to the buffer. Each element of items must have an outer dimension equal to batch_size, and the remaining dimensions must adhere to the data spec of the item (the same data spec passed to the replay buffer constructor).
Here's an example of adding a batch of items:
action = tf.constant(1 * np.ones(
    data_spec[0].shape.as_list(), dtype=np.float32))
lidar = tf.constant(
    2 * np.ones(data_spec[1][0].shape.as_list(), dtype=np.float32))
camera = tf.constant(
    3 * np.ones(data_spec[1][1].shape.as_list(), dtype=np.float32))

values = (action, (lidar, camera))
# Stack batch_size copies so every leaf gets an outer dimension of batch_size.
values_batched = tf.nest.map_structure(lambda t: tf.stack([t] * batch_size),
                                       values)

replay_buffer.add_batch(values_batched)
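The shape rule for add_batch (every leaf must have shape [batch_size] + its spec shape) can be checked by hand. The sketch below is a hypothetical, dependency-free illustration: nested lists stand in for tensors, lists of ints stand in for TensorSpec shapes, and `map_structure` is a minimal stand-in for tf.nest.map_structure that only recurses into tuples. None of these helpers are TF-Agents APIs.

```python
def shape_of(x):
    """Shape of a rectangular nested list, e.g. [[0, 0, 0]] -> [1, 3]."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return shape


def map_structure(fn, *structs):
    """Minimal stand-in for tf.nest.map_structure over nested tuples."""
    if isinstance(structs[0], tuple):
        return tuple(map_structure(fn, *parts) for parts in zip(*structs))
    return fn(*structs)


def filled(shape, value):
    """Nested list of the given shape filled with value."""
    if not shape:
        return value
    return [filled(shape[1:], value) for _ in range(shape[0])]


def check_batch(spec_shapes, items, batch_size):
    """Raise if any leaf is not [batch_size] + its spec shape."""
    def check(spec_shape, item):
        expected = [batch_size] + spec_shape
        actual = shape_of(item)
        if actual != expected:
            raise ValueError(f"expected shape {expected}, got {actual}")
        return actual
    return map_structure(check, spec_shapes, items)


# Shapes mirror the action/lidar/camera spec; lists are leaves, tuples nest.
spec_shapes = ([3], ([5], [3, 2]))
batch_size = 32
items = map_structure(lambda s: filled([batch_size] + s, 1.0), spec_shapes)
print(check_batch(spec_shapes, items, batch_size))
# ([32, 3], ([32, 5], [32, 3, 2]))
```

A batch built with the wrong outer dimension (say 16 instead of 32) fails this check, which is the same mistake add_batch rejects.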
Reading from the buffer
There are three ways to read data from the TFUniformReplayBuffer:
- get_next(): returns one sample from the buffer. The sample batch size and number of timesteps returned can be specified via arguments to this method.
- as_dataset(): returns the replay buffer as a tf.data.Dataset. One can then create a dataset iterator and iterate through the samples of the items in the buffer.
- gather_all(): returns all the items in the buffer as a Tensor with shape [batch, time, data_spec].
Below are examples of how to read from the replay buffer using each of these methods:
# add more items to the buffer before reading
for _ in range(5):
  replay_buffer.add_batch(values_batched)

# Get one sample from the replay buffer with batch size 10 and 1 timestep:
sample = replay_buffer.get_next(sample_batch_size=10, num_steps=1)

# Convert the replay buffer to a tf.data.Dataset and iterate through it
dataset = replay_buffer.as_dataset(
    sample_batch_size=4,
    num_steps=2)

iterator = iter(dataset)
print("Iterator trajectories:")
trajectories = []
for _ in range(3):
  t, _ = next(iterator)
  trajectories.append(t)

print(tf.nest.map_structure(lambda t: t.shape, trajectories))

# Read all elements in the replay buffer:
trajectories = replay_buffer.gather_all()

print("Trajectories from gather all:")
print(tf.nest.map_structure(lambda t: t.shape, trajectories))
Iterator trajectories:
[(TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2]))), (TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2]))), (TensorShape([4, 2, 3]), (TensorShape([4, 2, 5]), TensorShape([4, 2, 3, 2])))]
Trajectories from gather all:
(TensorShape([32, 6, 3]), (TensorShape([32, 6, 5]), TensorShape([32, 6, 3, 2])))
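The leading [batch, time] dimensions in that output follow from how the buffer was filled: 32 segments and 6 add_batch calls (1 + 5), so gather_all sees 6 timesteps per segment, while as_dataset(sample_batch_size=4, num_steps=2) draws 4 windows of 2 consecutive steps. The pure-Python sketch below reproduces those two shapes. `sample_windows` and `gather_all` here are hypothetical helpers that ignore ring-buffer wrap-around, which a real buffer must handle.

```python
import random


def sample_windows(segments, sample_batch_size, num_steps, rng):
    """Pick sample_batch_size windows of num_steps consecutive items."""
    length = len(segments[0])
    batch = []
    for _ in range(sample_batch_size):
        b = rng.randrange(len(segments))           # which batch segment
        t = rng.randrange(length - num_steps + 1)  # window start in time
        batch.append(segments[b][t:t + num_steps])
    return batch


def gather_all(segments):
    # Every segment and every stored step: [batch, time].
    return [list(segment) for segment in segments]


# 32 segments, 6 add_batch calls so far; each item records (segment, time).
segments = [[(b, t) for t in range(6)] for b in range(32)]

sample = sample_windows(segments, sample_batch_size=4, num_steps=2,
                        rng=random.Random(0))
print(len(sample), len(sample[0]))          # 4 2
everything = gather_all(segments)
print(len(everything), len(everything[0]))  # 32 6
```

Each sampled window stays inside one segment and covers consecutive timesteps, which is why the sampled tensors have a time dimension of num_steps.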
PyUniformReplayBuffer has the same functionality as TFUniformReplayBuffer, but instead of TF variables, its data is stored in numpy arrays. This buffer can be used for out-of-graph data collection. Having the backing storage in numpy may make it easier for some applications to do data manipulation (such as indexing for updating priorities) without using TensorFlow variables. However, this implementation won't have the benefit of graph optimizations with TensorFlow.
Below is an example of instantiating a PyUniformReplayBuffer from the sample data spec defined above, converted to an array spec:
replay_buffer_capacity = 1000*32  # same capacity as the TFUniformReplayBuffer

py_replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
    capacity=replay_buffer_capacity,
    data_spec=tensor_spec.to_nest_array_spec(data_spec))
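Unlike the TF buffer, this constructor takes a single flat capacity, here 1000*32 = 32000, which matches batch_size × max_length from the TF example. The sketch below is a hypothetical, stdlib-only model of flat circular storage, with Python lists standing in for numpy arrays. The per-slot `priorities` list and `update_priority` method are illustrative assumptions that show why host-side storage makes in-place index updates easy; PyUniformReplayBuffer itself samples uniformly and is not claimed to expose them.

```python
class FlatRingBuffer:
    """Hypothetical single-capacity circular buffer with a priority per slot."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.items = [None] * capacity
        self.priorities = [0.0] * capacity
        self._next = 0   # slot the next add() writes to
        self._size = 0   # number of filled slots

    def add(self, item, priority=1.0):
        self.items[self._next] = item
        self.priorities[self._next] = priority
        self._next = (self._next + 1) % self.capacity
        self._size = min(self._size + 1, self.capacity)

    def update_priority(self, index, priority):
        # Plain indexing into host memory; no graph variables involved.
        self.priorities[index] = priority

    def __len__(self):
        return self._size


buf = FlatRingBuffer(capacity=1000 * 32)
print(buf.capacity)  # 32000

tiny = FlatRingBuffer(capacity=3)
for step in range(5):
    tiny.add(step)  # steps 3 and 4 overwrite slots 0 and 1
tiny.update_priority(2, 0.5)
print(tiny.items, len(tiny), tiny.priorities)  # [3, 4, 2] 3 [1.0, 1.0, 0.5]
```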
Using replay buffers during training
Now that we know how to create a replay buffer, write items to it and read from it, we can use it to store trajectories during the training of our agents.
First, let's look at how to use the replay buffer during data collection.
In TF-Agents we use a Driver (see the Driver tutorial for more details) to collect experience in an environment. To use a Driver, we specify an Observer, which is a function for the Driver to execute when it receives a trajectory.
Thus, to add trajectory elements to the replay buffer, we add an observer that calls add_batch(items) to add a batch of items to the replay buffer.
Below is an example of this with TFUniformReplayBuffer. We first create an environment, a network and an agent. Then we create a TFUniformReplayBuffer. Note that the specs of the trajectory elements in the replay buffer are equal to the agent's collect data spec. We then set its add_batch method as the observer for the driver that will do the data collection during our training:
env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)

q_net = q_network.QNetwork(
    tf_env.time_step_spec().observation,
    tf_env.action_spec(),
    fc_layer_params=(100,))

agent = dqn_agent.DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_net,
    optimizer=tf.compat.v1.train.AdamOptimizer(0.001))

replay_buffer_capacity = 1000

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    agent.collect_data_spec,
    batch_size=tf_env.batch_size,
    max_length=replay_buffer_capacity)

# Add an observer that adds to the replay buffer:
replay_observer = [replay_buffer.add_batch]

collect_steps_per_iteration = 10
collect_op = dynamic_step_driver.DynamicStepDriver(
  tf_env,
  agent.collect_policy,
  observers=replay_observer,
  num_steps=collect_steps_per_iteration).run()
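The observer contract used above is simply "a callable that receives each trajectory". The sketch below models it without TensorFlow: `run_driver` is a hypothetical stand-in for a step driver, a plain list's append method plays the role of replay_buffer.add_batch, and a second observer shows that every registered callable sees the same trajectory.

```python
def run_driver(step_fn, observers, num_steps):
    """Hypothetical stand-in for a step driver: step, then notify observers."""
    for step in range(num_steps):
        trajectory = step_fn(step)  # one (batched) trajectory element
        for observer in observers:
            observer(trajectory)    # e.g. replay_buffer.add_batch


stored = []        # plays the role of the replay buffer
rewards_seen = []  # a second, independent observer

run_driver(
    step_fn=lambda step: {"step": step, "reward": 1.0},
    observers=[stored.append, lambda traj: rewards_seen.append(traj["reward"])],
    num_steps=10,
)
print(len(stored), sum(rewards_seen))  # 10 10.0
```

Because observers are plain callables, adding metrics or logging alongside the replay buffer does not require changing the driver.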
Reading data for a train step
After adding trajectory elements to the replay buffer, we can read batches of trajectories from the replay buffer to use as input data for a train step.
Here is an example of how to train on trajectories from the replay buffer in a training loop:
# Read the replay buffer as a Dataset,
# read batches of 4 elements, each with 2 timesteps:
dataset = replay_buffer.as_dataset(
    sample_batch_size=4,
    num_steps=2)

iterator = iter(dataset)

num_train_steps = 10

for _ in range(num_train_steps):
  # The second element holds sample metadata, which is unused here.
  trajectories, _ = next(iterator)
  loss = agent.train(experience=trajectories)
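One way to see why the loop reads windows of num_steps=2 is that two consecutive stored steps are enough to form a single transition (observation, action, reward, next observation), which is what a one-step Q-learning update consumes. The sketch below is a hypothetical illustration that assumes each stored step is a dict with "observation", "action" and "reward" keys, and that a step's reward is the one received after taking its action (the convention TF-Agents trajectories follow). It is not how DqnAgent processes trajectories internally.

```python
def to_transitions(windows):
    """Turn [batch, 2] windows of steps into (obs, action, reward, next_obs)."""
    transitions = []
    for first, second in windows:  # each window has exactly two steps
        transitions.append((first["observation"], first["action"],
                            first["reward"], second["observation"]))
    return transitions


# Hypothetical steps: observation t, alternating actions, reward t.
steps = [{"observation": t, "action": t % 2, "reward": float(t)}
         for t in range(4)]
windows = [steps[0:2], steps[2:4]]
print(to_transitions(windows))  # [(0, 0, 0.0, 1), (2, 0, 2.0, 3)]
```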