Entrene una red Deep Q con TF-Agents

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Introducción

Este ejemplo muestra cómo entrenar a un (Deep Q Redes) DQN agente en el medio ambiente Cartpole usando la biblioteca TF-Agentes.

Entorno de cartpole

Lo guiará a través de todos los componentes de una canalización de aprendizaje reforzado (RL) para capacitación, evaluación y recopilación de datos.

Para ejecutar este código en vivo, haga clic en el enlace 'Ejecutar en Google Colab' arriba.

Configuración

Si no ha instalado las siguientes dependencias, ejecute:

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg freeglut3-dev
pip install 'imageio==2.4.0'
pip install pyvirtualdisplay
pip install tf-agents[reverb]
pip install pyglet
from __future__ import absolute_import, division, print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay
import reverb

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import sequential
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.trajectories import trajectory
from tf_agents.specs import tensor_spec
from tf_agents.utils import common
# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
tf.version.VERSION
'2.6.0'

Hiperparámetros

num_iterations = 20000 # @param {type:"integer"}

initial_collect_steps = 100  # @param {type:"integer"}
collect_steps_per_iteration =   1# @param {type:"integer"}
replay_buffer_max_length = 100000  # @param {type:"integer"}

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
log_interval = 200  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

Medio ambiente

En el aprendizaje por refuerzo (RL), un entorno representa la tarea o problema a resolver. Entornos estándar se pueden crear en la carretera TF-Agentes utilizando tf_agents.environments suites. TF-Agents tiene suites para cargar entornos de fuentes como OpenAI Gym, Atari y DM Control.

Cargue el entorno CartPole desde la suite OpenAI Gym.

env_name = 'CartPole-v0'
env = suite_gym.load(env_name)

Puede renderizar este entorno para ver cómo se ve. Un poste de oscilación libre está sujeto a un carro. El objetivo es mover el carro hacia la derecha o hacia la izquierda para mantener el poste apuntando hacia arriba.

env.reset()
PIL.Image.fromarray(env.render())

png

El environment.step método toma una action en el medio ambiente y devuelve un TimeStep tupla que contiene la siguiente observación del entorno y la recompensa de la acción.

El time_step_spec() método devuelve la especificación para el TimeStep tupla. Sus observation de atributos muestra la forma de observaciones, los tipos de datos, y los rangos de valores permitidos. La reward atributo muestra los mismos detalles para la recompensa.

print('Observation Spec:')
print(env.time_step_spec().observation)
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])
print('Reward Spec:')
print(env.time_step_spec().reward)
Reward Spec:
ArraySpec(shape=(), dtype=dtype('float32'), name='reward')

El action_spec() método devuelve la forma, tipos de datos y valores permitidos de acciones válidas.

print('Action Spec:')
print(env.action_spec())
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

En el entorno de Cartpole:

  • observation es una matriz de 4 flotadores:
    • la posición y velocidad del carro
    • la posición angular y la velocidad del polo
  • reward es un valor flotante escalar
  • action es un número entero escalar con sólo dos valores posibles:
    • 0 - "mover hacia la izquierda"
    • 1 - "decisión correcta"
time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step:
TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([ 0.02281696, -0.00137907,  0.04442764, -0.03935837], dtype=float32),
 'reward': array(0., dtype=float32),
 'step_type': array(0, dtype=int32)})
Next time step:
TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([ 0.02278937,  0.19307856,  0.04364047, -0.31769958], dtype=float32),
 'reward': array(1., dtype=float32),
 'step_type': array(1, dtype=int32)})

Por lo general, se crean instancias de dos entornos: uno para entrenamiento y otro para evaluación.

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

El entorno Cartpole, como la mayoría de los entornos, está escrito en Python puro. Esta se convierte en TensorFlow usando el TFPyEnvironment envoltura.

La API del entorno original utiliza matrices Numpy. Los TFPyEnvironment convertidos a estos Tensors para que sea compatible con los agentes y políticas Tensorflow.

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Agente

El algoritmo utilizado para resolver un problema de RL está representado por un Agent . TF-Agentes proporciona implementaciones estándar de una variedad de Agents , que incluyen:

El agente DQN se puede utilizar en cualquier entorno que tenga un espacio de acción discreto.

En el corazón de un agente DQN es un QNetwork , un modelo de red neuronal que se puede aprender a predecir QValues (retornos esperados) para todas las acciones, teniendo en cuenta una observación del entorno.

Vamos a utilizar tf_agents.networks. para crear un QNetwork . La red consistirá en una secuencia de tf.keras.layers.Dense capas, donde la capa final tendrá 1 de salida para cada acción posible.

fc_layer_params = (100, 50)
action_tensor_spec = tensor_spec.from_spec(env.action_spec())
num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

# Define a helper function to create Dense layers configured with the right
# activation and kernel initializer.
def dense_layer(num_units):
  return tf.keras.layers.Dense(
      num_units,
      activation=tf.keras.activations.relu,
      kernel_initializer=tf.keras.initializers.VarianceScaling(
          scale=2.0, mode='fan_in', distribution='truncated_normal'))

# QNetwork consists of a sequence of Dense layers followed by a dense layer
# with `num_actions` units to generate one q_value per available action as
# its output.
dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
q_values_layer = tf.keras.layers.Dense(
    num_actions,
    activation=None,
    kernel_initializer=tf.keras.initializers.RandomUniform(
        minval=-0.03, maxval=0.03),
    bias_initializer=tf.keras.initializers.Constant(-0.2))
q_net = sequential.Sequential(dense_layers + [q_values_layer])

Ahora usa tf_agents.agents.dqn.dqn_agent a instancias de un DqnAgent . Además de la time_step_spec , action_spec y la QNetwork, el constructor agente también requiere un optimizador (en este caso, AdamOptimizer ), una función de pérdida, y un contador de paso entero.

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

agent.initialize()

Políticas

Una política define la forma en que un agente actúa en un entorno. Normalmente, el objetivo del aprendizaje por refuerzo es entrenar el modelo subyacente hasta que la política produzca el resultado deseado.

En este tutorial:

  • El resultado deseado es mantener el poste en equilibrio sobre el carro.
  • La política devuelve una acción (izquierda o derecha) para cada time_step observación.

Los agentes contienen dos políticas:

  • agent.policy - La principal política que se utiliza para la evaluación y la implementación.
  • agent.collect_policy - Una segunda política que se utiliza para la recolección de datos.
eval_policy = agent.policy
collect_policy = agent.collect_policy

Las políticas se pueden crear independientemente de los agentes. Por ejemplo, utilice tf_agents.policies.random_tf_policy para crear una política que seleccionará al azar una acción para cada time_step .

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

Para conseguir una acción de una política, llame al policy.action(time_step) método. El time_step contiene la observación del entorno. Este método devuelve un PolicyStep , que es una tupla llamada con tres componentes:

  • action - la acción a tomar (en este caso, 0 o 1 )
  • state - se utiliza para las políticas de estado (es decir, RNN-based)
  • info - datos auxiliares, tales como las probabilidades de registro de acciones
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()
random_policy.action(time_step)
PolicyStep(action=<tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>, state=(), info=())

Métricas y evaluación

La métrica más común utilizada para evaluar una póliza es el rendimiento promedio. El retorno es la suma de las recompensas obtenidas al ejecutar una política en un entorno para un episodio. Se ejecutan varios episodios, lo que genera un rendimiento medio.

La siguiente función calcula el rendimiento promedio de una política, dados la política, el entorno y una serie de episodios.

def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


# See also the metrics module for standard implementations of different metrics.
# https://github.com/tensorflow/agents/tree/master/tf_agents/metrics

La ejecución de este cálculo en la random_policy muestra un rendimiento de referencia en el entorno.

compute_avg_return(eval_env, random_policy, num_eval_episodes)
21.1

Búfer de reproducción

Con el fin de realizar un seguimiento de los datos recogidos en el medio ambiente, vamos a utilizar reverberación , un sistema de repetición eficiente, extensible y fácil de usar por Deepmind. Almacena datos de experiencia cuando recopilamos trayectorias y se consume durante el entrenamiento.

Este búfer de reproducción se construye utilizando especificaciones que describen los tensores que se almacenarán, que se pueden obtener del agente mediante agent.collect_data_spec.

table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(
      agent.collect_data_spec)
table = reverb.Table(
    table_name,
    max_size=replay_buffer_max_length,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=replay_buffer_signature)

reverb_server = reverb.Server([table])

replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    agent.collect_data_spec,
    table_name=table_name,
    sequence_length=2,
    local_server=reverb_server)

rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
  replay_buffer.py_client,
  table_name,
  sequence_length=2)
[reverb/cc/platform/tfrecord_checkpointer.cc:150]  Initializing TFRecordCheckpointer in /tmp/tmpa2d5wjuv.
[reverb/cc/platform/tfrecord_checkpointer.cc:380] Loading latest checkpoint from /tmp/tmpa2d5wjuv
[reverb/cc/platform/default/server.cc:71] Started replay server on port 17916

Para la mayoría de los agentes, collect_data_spec es una tupla con nombre denominada Trajectory , que contiene las especificaciones para las observaciones, acciones, recompensas y otros artículos.

agent.collect_data_spec
Trajectory(
{'action': BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0), maximum=array(1)),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32)),
 'policy_info': (),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type')})
agent.collect_data_spec._fields
('step_type',
 'observation',
 'action',
 'policy_info',
 'next_step_type',
 'reward',
 'discount')

Recopilación de datos

Ahora ejecute la política aleatoria en el entorno durante algunos pasos, registrando los datos en el búfer de reproducción.

Aquí estamos usando 'PyDriver' para ejecutar el ciclo de recolección de experiencias. Usted puede aprender más acerca controlador Agentes TF en nuestro tutorial de los conductores .

py_driver.PyDriver(
    env,
    py_tf_eager_policy.PyTFEagerPolicy(
      random_policy, use_tf_function=True),
    [rb_observer],
    max_steps=initial_collect_steps).run(train_py_env.reset())
(TimeStep(
 {'discount': array(1., dtype=float32),
  'observation': array([ 0.09068768,  1.0256505 , -0.19590192, -1.8110262 ], dtype=float32),
  'reward': array(1., dtype=float32),
  'step_type': array(1, dtype=int32)}),
 ())

El búfer de reproducción ahora es una colección de trayectorias.

# For the curious:
# Uncomment to peel one of these off and inspect it.
# iter(replay_buffer.as_dataset()).next()

El agente necesita acceso al búfer de reproducción. Esto se proporciona mediante la creación de un iterable tf.data.Dataset tubería que alimentarán datos al agente.

Cada fila del búfer de reproducción solo almacena un único paso de observación. Pero dado que el agente DQN necesita tanto la observación actual y la siguiente para calcular la pérdida, la tubería conjunto de datos se muestra dos filas adyacentes de cada elemento en el lote ( num_steps=2 ).

Este conjunto de datos también se optimiza mediante la ejecución de llamadas paralelas y la obtención previa de datos.

# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=batch_size,
    num_steps=2).prefetch(3)

dataset
<PrefetchDataset shapes: (Trajectory(
{action: (64, 2),
 discount: (64, 2),
 next_step_type: (64, 2),
 observation: (64, 2, 4),
 policy_info: (),
 reward: (64, 2),
 step_type: (64, 2)}), SampleInfo(key=(64, 2), probability=(64, 2), table_size=(64, 2), priority=(64, 2))), types: (Trajectory(
{action: tf.int64,
 discount: tf.float32,
 next_step_type: tf.int32,
 observation: tf.float32,
 policy_info: (),
 reward: tf.float32,
 step_type: tf.int32}), SampleInfo(key=tf.uint64, probability=tf.float64, table_size=tf.int64, priority=tf.float64))>
iterator = iter(dataset)
print(iterator)
<tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x7fbc63fd6990>
# For the curious:
# Uncomment to see what the dataset iterator is feeding to the agent.
# Compare this representation of replay data 
# to the collection of individual trajectories shown earlier.

# iterator.next()

Entrenando al agente

Deben suceder dos cosas durante el ciclo de entrenamiento:

  • recopilar datos del medio ambiente
  • usar esos datos para entrenar las redes neuronales del agente

Este ejemplo también evalúa periódicamente la política e imprime la puntuación actual.

Lo siguiente tardará ~ 5 minutos en ejecutarse.

try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step.
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

# Reset the environment.
time_step = train_py_env.reset()

# Create a driver to collect experience.
collect_driver = py_driver.PyDriver(
    env,
    py_tf_eager_policy.PyTFEagerPolicy(
      agent.collect_policy, use_tf_function=True),
    [rb_observer],
    max_steps=collect_steps_per_iteration)

for _ in range(num_iterations):

  # Collect a few steps and save to the replay buffer.
  time_step, _ = collect_driver.run(time_step)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience).loss

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:206: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (13133) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (13133) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (13133) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (13133) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (13133) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (13133) so Table uniform_table is accessed directly without gRPC.
step = 200: loss = 16.648088455200195
step = 400: loss = 8.659958839416504
step = 600: loss = 54.335906982421875
step = 800: loss = 8.27525806427002
step = 1000: loss = 166.111572265625
step = 1000: Average Return = 122.0999984741211
step = 1200: loss = 313.59765625
step = 1400: loss = 37.847381591796875
step = 1600: loss = 288.4794616699219
step = 1800: loss = 3317.88525390625
step = 2000: loss = 173.12469482421875
step = 2000: Average Return = 199.8000030517578
step = 2200: loss = 6418.63525390625
step = 2400: loss = 1484.861328125
step = 2600: loss = 196.24331665039062
step = 2800: loss = 149.65438842773438
step = 3000: loss = 916.5035400390625
step = 3000: Average Return = 198.6999969482422
step = 3200: loss = 162.61080932617188
step = 3400: loss = 13968.1435546875
step = 3600: loss = 302.0409851074219
step = 3800: loss = 2836.1259765625
step = 4000: loss = 1076.712646484375
step = 4000: Average Return = 200.0
step = 4200: loss = 20416.47265625
step = 4400: loss = 1810.417236328125
step = 4600: loss = 3067.5947265625
step = 4800: loss = 3112.873291015625
step = 5000: loss = 6336.5771484375
step = 5000: Average Return = 200.0
step = 5200: loss = 6338.7412109375
step = 5400: loss = 5041.7099609375
step = 5600: loss = 4270.40869140625
step = 5800: loss = 16978.80859375
step = 6000: loss = 16609.892578125
step = 6000: Average Return = 200.0
step = 6200: loss = 7808.3515625
step = 6400: loss = 9949.4765625
step = 6600: loss = 19038.453125
step = 6800: loss = 38204.29296875
step = 7000: loss = 287291.90625
step = 7000: Average Return = 200.0
step = 7200: loss = 121852.53125
step = 7400: loss = 35913.296875
step = 7600: loss = 49503.0703125
step = 7800: loss = 49284.46875
step = 8000: loss = 100121.359375
step = 8000: Average Return = 200.0
step = 8200: loss = 56341.046875
step = 8400: loss = 1447406.5
step = 8600: loss = 140221.96875
step = 8800: loss = 150099.921875
step = 9000: loss = 163611.125
step = 9000: Average Return = 200.0
step = 9200: loss = 1166357.75
step = 9400: loss = 144310.484375
step = 9600: loss = 217644.84375
step = 9800: loss = 155856.6875
step = 10000: loss = 185425.078125
step = 10000: Average Return = 200.0
step = 10200: loss = 189458.25
step = 10400: loss = 308120.9375
step = 10600: loss = 455290.5
step = 10800: loss = 302644.6875
step = 11000: loss = 3637171.25
step = 11000: Average Return = 200.0
step = 11200: loss = 214862.234375
step = 11400: loss = 2360442.25
step = 11600: loss = 14211988.0
step = 11800: loss = 253858.125
step = 12000: loss = 649448.375
step = 12000: Average Return = 200.0
step = 12200: loss = 254648.890625
step = 12400: loss = 2927912.25
step = 12600: loss = 2999200.75
step = 12800: loss = 1350071.0
step = 13000: loss = 1172909.5
step = 13000: Average Return = 200.0
step = 13200: loss = 381366.5625
step = 13400: loss = 107656960.0
step = 13600: loss = 309260.84375
step = 13800: loss = 932287.0
step = 14000: loss = 1813161.5
step = 14000: Average Return = 200.0
step = 14200: loss = 1147879.125
step = 14400: loss = 2045360.875
step = 14600: loss = 752043.5625
step = 14800: loss = 800067.25
step = 15000: loss = 4558073.5
step = 15000: Average Return = 200.0
step = 15200: loss = 1908942.25
step = 15400: loss = 216246080.0
step = 15600: loss = 598498.3125
step = 15800: loss = 15048959.0
step = 16000: loss = 1509748.0
step = 16000: Average Return = 200.0
step = 16200: loss = 446024.5625
step = 16400: loss = 739561.25
step = 16600: loss = 4358108.5
step = 16800: loss = 2399731.0
step = 17000: loss = 350459584.0
step = 17000: Average Return = 200.0
step = 17200: loss = 2464587.25
step = 17400: loss = 3177516.0
step = 17600: loss = 775946.5625
step = 17800: loss = 2545362.5
step = 18000: loss = 4703361.0
step = 18000: Average Return = 200.0
step = 18200: loss = 5975770.5
step = 18400: loss = 143910896.0
step = 18600: loss = 1918198.25
step = 18800: loss = 3589433.0
step = 19000: loss = 4773322.5
step = 19000: Average Return = 200.0
step = 19200: loss = 5001252.0
step = 19400: loss = 9087488.0
step = 19600: loss = 3865313.5
step = 19800: loss = 3663483.75
step = 20000: loss = 2048965.25
step = 20000: Average Return = 200.0

Visualización

Parcelas

Uso matplotlib.pyplot para trazar cómo mejoró la política durante el entrenamiento.

Una iteración de Cartpole-v0 consiste en 200 pasos de tiempo. El entorno da una recompensa de +1 para cada paso de las estancias de polo hacia arriba, por lo que el rendimiento máximo para un episodio es 200. Los gráficos muestra la rentabilidad creciente hacia ese máximo cada vez que se evalúa durante el entrenamiento. (Puede ser un poco inestable y no aumentar monótonamente cada vez).

iterations = range(0, num_iterations + 1, eval_interval)
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')
plt.ylim(top=250)
(0.1849997997283932, 250.0)

png

Videos

Los gráficos son agradables. Pero lo más emocionante es ver a un agente realizando una tarea en un entorno.

Primero, cree una función para incrustar videos en el cuaderno.

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

Ahora repite algunos episodios del juego Cartpole con el agente. El entorno Python subyacente (el "dentro" del entorno envoltura TensorFlow) proporciona un render() método, que da salida a una imagen del estado medio ambiente. Estos se pueden recopilar en un video.

def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):
  filename = filename + ".mp4"
  with imageio.get_writer(filename, fps=fps) as video:
    for _ in range(num_episodes):
      time_step = eval_env.reset()
      video.append_data(eval_py_env.render())
      while not time_step.is_last():
        action_step = policy.action(time_step)
        time_step = eval_env.step(action_step.action)
        video.append_data(eval_py_env.render())
  return embed_mp4(filename)

create_policy_eval_video(agent.policy, "trained-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.
[swscaler @ 0x55a095bda3c0] Warning: data is not aligned! This can lead to a speed loss

Para divertirse, compare al agente entrenado (arriba) con un agente que se mueve al azar. (No funciona tan bien).

create_policy_eval_video(random_policy, "random-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.
[swscaler @ 0x55ac66ca23c0] Warning: data is not aligned! This can lead to a speed loss