Esta página foi traduzida pela API Cloud Translation.
Switch to English

Treine uma rede Deep Q com agentes TF

Ver em TensorFlow.org Executar no Google Colab Ver fonte no GitHub Download do caderno

Introdução

Este exemplo mostra como treinar um agente DQN (Deep Q Networks) no ambiente Cartpole usando a biblioteca TF-Agents.

Cartpole environment

Ele o guiará por todos os componentes em um pipeline de Aprendizagem por Reforço (RL) para treinamento, avaliação e coleta de dados.

Para executar esse código ao vivo, clique no link "Executar no Google Colab" acima.

Configuração

Se você não instalou as seguintes dependências, execute:

sudo apt-get install -y xvfb ffmpeg
pip install -q 'gym==0.10.11'
pip install -q 'imageio==2.4.0'
pip install -q PILLOW
pip install -q 'pyglet==1.3.2'
pip install -q pyvirtualdisplay
pip install -q --pre tf-agents[reverb]



ffmpeg is already the newest version (7:3.4.8-0ubuntu0.2).
xvfb is already the newest version (2:1.19.6-1ubuntu4.4).
The following packages were automatically installed and are no longer required:
  dconf-gsettings-backend dconf-service dkms freeglut3 freeglut3-dev
  glib-networking glib-networking-common glib-networking-services
  gsettings-desktop-schemas libcairo-gobject2 libcolord2 libdconf1
  libegl1-mesa libepoxy0 libglu1-mesa libglu1-mesa-dev libgtk-3-0
  libgtk-3-common libice-dev libjansson4 libjson-glib-1.0-0
  libjson-glib-1.0-common libproxy1v5 librest-0.7-0 libsm-dev
  libsoup-gnome2.4-1 libsoup2.4-1 libxi-dev libxmu-dev libxmu-headers
  libxnvctrl0 libxt-dev linux-gcp-headers-5.0.0-1026
  linux-headers-5.0.0-1026-gcp linux-image-5.0.0-1026-gcp
  linux-modules-5.0.0-1026-gcp pkg-config policykit-1-gnome python3-xkit
  screen-resolution-extra xserver-xorg-core-hwe-18.04
Use 'sudo apt autoremove' to remove them.
0 upgraded, 0 newly installed, 0 to remove and 90 not upgraded.
WARNING: You are using pip version 20.1.1; however, version 20.2 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
WARNING: You are using pip version 20.1.1; however, version 20.2 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
WARNING: You are using pip version 20.1.1; however, version 20.2 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
WARNING: You are using pip version 20.1.1; however, version 20.2 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
WARNING: You are using pip version 20.1.1; however, version 20.2 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.
WARNING: You are using pip version 20.1.1; however, version 20.2 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

 from __future__ import absolute_import, division, print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
 
 tf.compat.v1.enable_v2_behavior()

# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
 
 tf.version.VERSION
 
'2.3.0'

Hiperparâmetros

 num_iterations = 20000 # @param {type:"integer"}

initial_collect_steps = 1000  # @param {type:"integer"} 
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_max_length = 100000  # @param {type:"integer"}

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
log_interval = 200  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}
 

Meio Ambiente

No Aprendizado por Reforço (RL), um ambiente representa a tarefa ou o problema a ser resolvido. Ambientes padrão podem ser criados em TF-Agents usando os conjuntos tf_agents.environments . O TF-Agents possui suítes para carregar ambientes de fontes como o OpenAI Gym, Atari e DM Control.

Carregue o ambiente CartPole da suíte OpenAI Gym.

 env_name = 'CartPole-v0'
env = suite_gym.load(env_name)
 

Você pode renderizar esse ambiente para ver como ele se parece. Um poste de balanço livre está preso a um carrinho. O objetivo é mover o carrinho para a direita ou esquerda, a fim de manter o poste apontando para cima.

 
env.reset()
PIL.Image.fromarray(env.render())
 

png

O environment.step método tem uma action no ambiente e retorna um TimeStep tupla contendo a seguinte observação do ambiente e a recompensa para a acção.

O método time_step_spec() retorna a especificação para a tupla TimeStep . Seu atributo de observation mostra a forma das observações, os tipos de dados e os intervalos de valores permitidos. O atributo de reward mostra os mesmos detalhes para a recompensa.

 print('Observation Spec:')
print(env.time_step_spec().observation)
 
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])

 print('Reward Spec:')
print(env.time_step_spec().reward)
 
Reward Spec:
ArraySpec(shape=(), dtype=dtype('float32'), name='reward')

O método action_spec() retorna a forma, tipos de dados e valores permitidos de ações válidas.

 print('Action Spec:')
print(env.action_spec())
 
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

No ambiente Cartpole:

  • observation é uma matriz de 4 carros alegóricos:
    • a posição e velocidade do carrinho
    • a posição angular e velocidade do polo
  • reward é um valor de flutuação escalar
  • action é um número inteiro escalar com apenas dois valores possíveis:
    • 0 - "mover para a esquerda"
    • 1 - "mova para a direita"
 time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
 
Time step:
TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.03945196, -0.01603654, -0.02420856,  0.01585053], dtype=float32))
Next time step:
TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.03913123,  0.17942408, -0.02389155, -0.2843711 ], dtype=float32))

Geralmente, dois ambientes são instanciados: um para treinamento e outro para avaliação.

 train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
 

O ambiente Cartpole, como a maioria dos ambientes, é escrito em Python puro. Isso é convertido em TensorFlow usando o wrapper TFPyEnvironment .

A API do ambiente original usa matrizes Numpy. O TFPyEnvironment converte em Tensors para torná-lo compatível com os agentes e políticas do Tensorflow.

 train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
 

Agente

O algoritmo usado para resolver um problema de RL é representado por um Agent . O TF-Agents fornece implementações padrão de uma variedade de Agents , incluindo:

O agente DQN pode ser usado em qualquer ambiente que tenha um espaço de ação discreto.

No coração de um agente DQN está uma QNetwork , um modelo de rede neural que pode aprender a prever QValues (retornos esperados) para todas as ações, com uma observação do ambiente.

Use tf_agents.networks.q_network para criar uma QNetwork , transmitindo a observation_spec , action_spec e uma tupla que descreve o número e o tamanho das camadas ocultas do modelo.

 fc_layer_params = (100,)

q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)
 

Agora use tf_agents.agents.dqn.dqn_agent para instanciar um DqnAgent . Além do time_step_spec , action_spec e QNetwork, o construtor do agente também requer um otimizador (neste caso, AdamOptimizer ), uma função de perda e um contador de etapas inteiro.

 optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

agent.initialize()
 

Políticas

Uma política define a maneira como um agente age em um ambiente. Normalmente, o objetivo do aprendizado por reforço é treinar o modelo subjacente até que a política produza o resultado desejado.

Neste tutorial:

  • O resultado desejado é manter o poste equilibrado na vertical sobre o carrinho.
  • A política retorna uma ação (esquerda ou direita) para cada observação time_step .

Os agentes contêm duas políticas:

  • agent.policy - A principal política usada para avaliação e implantação.
  • agent.collect_policy - Uma segunda política usada para coleta de dados.
 eval_policy = agent.policy
collect_policy = agent.collect_policy
 

As políticas podem ser criadas independentemente dos agentes. Por exemplo, use tf_agents.policies.random_tf_policy para criar uma política que selecione aleatoriamente uma ação para cada time_step .

 random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())
 

Para obter uma ação de uma política, chame o policy.action(time_step) . O time_step contém a observação do ambiente. Esse método retorna um PolicyStep , que é uma tupla nomeada com três componentes:

  • action - a ação a ser tomada (neste caso, 0 ou 1 )
  • state - usado para políticas com estado (ou seja, baseadas em RNN)
  • info - dados auxiliares, como probabilidades de log de ações
 example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
 
 time_step = example_environment.reset()
 
 random_policy.action(time_step)
 
PolicyStep(action=<tf.Tensor: shape=(1,), dtype=int64, numpy=array([1])>, state=(), info=())

Métricas e Avaliação

A métrica mais comum usada para avaliar uma política é o retorno médio. O retorno é a soma das recompensas obtidas durante a execução de uma política em um ambiente para um episódio. Vários episódios são executados, criando um retorno médio.

A função a seguir calcula o retorno médio de uma política, considerando a política, o ambiente e vários episódios.

 
def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


# See also the metrics module for standard implementations of different metrics.
# https://github.com/tensorflow/agents/tree/master/tf_agents/metrics
 

A execução desse cálculo na random_policy mostra um desempenho de linha de base no ambiente.

 compute_avg_return(eval_env, random_policy, num_eval_episodes)
 
21.2

Replay Buffer

O buffer de reprodução controla os dados coletados do ambiente. Este tutorial usa tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer , pois é o mais comum.

O construtor requer as especificações para os dados que coletará. Está disponível no agente usando o método collect_data_spec . O tamanho do lote e o tamanho máximo do buffer também são necessários.

 replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_max_length)
 

Para a maioria dos agentes, collect_data_spec é uma tupla nomeada chamada Trajectory , contendo as especificações para observações, ações, recompensas e outros itens.

 agent.collect_data_spec
 
Trajectory(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), observation=BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32)), action=BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0), maximum=array(1)), policy_info=(), next_step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)))
 agent.collect_data_spec._fields
 
('step_type',
 'observation',
 'action',
 'policy_info',
 'next_step_type',
 'reward',
 'discount')

Coleção de dados

Agora execute a política aleatória no ambiente por algumas etapas, registrando os dados no buffer de reprodução.

 
def collect_step(environment, policy, buffer):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer
  buffer.add_batch(traj)

def collect_data(env, policy, buffer, steps):
  for _ in range(steps):
    collect_step(env, policy, buffer)

collect_data(train_env, random_policy, replay_buffer, steps=100)

# This loop is so common in RL, that we provide standard implementations. 
# For more details see the drivers module.
# https://www.tensorflow.org/agents/api_docs/python/tf_agents/drivers
 

O buffer de reprodução agora é uma coleção de Trajetórias.

 # For the curious:
# Uncomment to peel one of these off and inspect it.
# iter(replay_buffer.as_dataset()).next()
 

O agente precisa acessar o buffer de reprodução. Isso é fornecido através da criação de um pipeline iterável tf.data.Dataset que alimenta os dados para o agente.

Cada linha do buffer de reprodução armazena apenas uma única etapa de observação. Porém, como o DQN Agent precisa da observação atual e da próxima para calcular a perda, o pipeline do conjunto de dados fará uma amostra de duas linhas adjacentes para cada item do lote ( num_steps=2 ).

Esse conjunto de dados também é otimizado executando chamadas paralelas e pré-buscando dados.

 # Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, 
    sample_batch_size=batch_size, 
    num_steps=2).prefetch(3)


dataset
 
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/autograph/operators/control_flow.py:1004: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

<PrefetchDataset shapes: (Trajectory(step_type=(64, 2), observation=(64, 2, 4), action=(64, 2), policy_info=(), next_step_type=(64, 2), reward=(64, 2), discount=(64, 2)), BufferInfo(ids=(64, 2), probabilities=(64,))), types: (Trajectory(step_type=tf.int32, observation=tf.float32, action=tf.int64, policy_info=(), next_step_type=tf.int32, reward=tf.float32, discount=tf.float32), BufferInfo(ids=tf.int64, probabilities=tf.float32))>
 iterator = iter(dataset)

print(iterator)

 
<tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x7f49c855d1d0>

 # For the curious:
# Uncomment to see what the dataset iterator is feeding to the agent.
# Compare this representation of replay data 
# to the collection of individual trajectories shown earlier.

# iterator.next()
 

Treinando o agente

Duas coisas devem acontecer durante o ciclo de treinamento:

  • coletar dados do ambiente
  • use esses dados para treinar as redes neurais do agente

Este exemplo também avalia periodicamente a política e imprime a pontuação atual.

O seguinte levará aproximadamente 5 minutos para ser executado.

 
try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few steps using collect_policy and save to the replay buffer.
  for _ in range(collect_steps_per_iteration):
    collect_step(train_env, agent.collect_policy, replay_buffer)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience).loss

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
 
step = 200: loss = 19.690689086914062
step = 400: loss = 16.47402572631836
step = 600: loss = 4.861778736114502
step = 800: loss = 4.024849891662598
step = 1000: loss = 2.6150379180908203
step = 1000: Average Return = 13.899999618530273
step = 1200: loss = 8.585780143737793
step = 1400: loss = 4.644379615783691
step = 1600: loss = 5.039738655090332
step = 1800: loss = 11.336905479431152
step = 2000: loss = 3.268812656402588
step = 2000: Average Return = 33.5
step = 2200: loss = 11.0034761428833
step = 2400: loss = 16.17085075378418
step = 2600: loss = 5.645272254943848
step = 2800: loss = 9.536992073059082
step = 3000: loss = 27.910123825073242
step = 3000: Average Return = 42.5
step = 3200: loss = 16.461536407470703
step = 3400: loss = 35.99446105957031
step = 3600: loss = 17.244731903076172
step = 3800: loss = 64.8316421508789
step = 4000: loss = 45.81110382080078
step = 4000: Average Return = 66.0999984741211
step = 4200: loss = 38.792320251464844
step = 4400: loss = 24.844989776611328
step = 4600: loss = 62.42521667480469
step = 4800: loss = 5.556773662567139
step = 5000: loss = 56.66333770751953
step = 5000: Average Return = 179.1999969482422
step = 5200: loss = 31.51094627380371
step = 5400: loss = 5.266134262084961
step = 5600: loss = 96.6891098022461
step = 5800: loss = 127.47486114501953
step = 6000: loss = 12.055898666381836
step = 6000: Average Return = 193.10000610351562
step = 6200: loss = 27.283191680908203
step = 6400: loss = 8.846441268920898
step = 6600: loss = 19.135318756103516
step = 6800: loss = 279.4776916503906
step = 7000: loss = 41.02388000488281
step = 7000: Average Return = 200.0
step = 7200: loss = 228.2138214111328
step = 7400: loss = 151.24070739746094
step = 7600: loss = 62.31920623779297
step = 7800: loss = 16.921720504760742
step = 8000: loss = 37.32741165161133
step = 8000: Average Return = 200.0
step = 8200: loss = 35.77818298339844
step = 8400: loss = 56.06548309326172
step = 8600: loss = 388.6180114746094
step = 8800: loss = 98.04512023925781
step = 9000: loss = 102.46206665039062
step = 9000: Average Return = 200.0
step = 9200: loss = 72.02937316894531
step = 9400: loss = 215.56304931640625
step = 9600: loss = 298.1328430175781
step = 9800: loss = 273.0595397949219
step = 10000: loss = 31.24103546142578
step = 10000: Average Return = 200.0
step = 10200: loss = 493.45379638671875
step = 10400: loss = 321.92828369140625
step = 10600: loss = 127.94703674316406
step = 10800: loss = 105.52901458740234
step = 11000: loss = 10.25836181640625
step = 11000: Average Return = 200.0
step = 11200: loss = 34.508384704589844
step = 11400: loss = 292.8785400390625
step = 11600: loss = 282.56976318359375
step = 11800: loss = 419.93927001953125
step = 12000: loss = 149.61669921875
step = 12000: Average Return = 200.0
step = 12200: loss = 662.1378173828125
step = 12400: loss = 6264.18017578125
step = 12600: loss = 1844.2698974609375
step = 12800: loss = 1308.2515869140625
step = 13000: loss = 1177.77001953125
step = 13000: Average Return = 200.0
step = 13200: loss = 425.58929443359375
step = 13400: loss = 51.96171188354492
step = 13600: loss = 104.29891967773438
step = 13800: loss = 535.3070678710938
step = 14000: loss = 49.53106689453125
step = 14000: Average Return = 200.0
step = 14200: loss = 942.129638671875
step = 14400: loss = 52.236656188964844
step = 14600: loss = 69.46319580078125
step = 14800: loss = 60.80382537841797
step = 15000: loss = 784.780517578125
step = 15000: Average Return = 200.0
step = 15200: loss = 211.77589416503906
step = 15400: loss = 3985.186279296875
step = 15600: loss = 258.67987060546875
step = 15800: loss = 1084.543701171875
step = 16000: loss = 1418.936767578125
step = 16000: Average Return = 199.3000030517578
step = 16200: loss = 135.34539794921875
step = 16400: loss = 469.0115966796875
step = 16600: loss = 65.48590087890625
step = 16800: loss = 1503.060791015625
step = 17000: loss = 179.79794311523438
step = 17000: Average Return = 200.0
step = 17200: loss = 74.41525268554688
step = 17400: loss = 410.6675720214844
step = 17600: loss = 89.18785095214844
step = 17800: loss = 107.3138198852539
step = 18000: loss = 138.09153747558594
step = 18000: Average Return = 200.0
step = 18200: loss = 52.20016098022461
step = 18400: loss = 76.33519744873047
step = 18600: loss = 954.2196655273438
step = 18800: loss = 210.16331481933594
step = 19000: loss = 729.912841796875
step = 19000: Average Return = 200.0
step = 19200: loss = 80.44801330566406
step = 19400: loss = 4124.3759765625
step = 19600: loss = 54.86518859863281
step = 19800: loss = 144.1697235107422
step = 20000: loss = 28250.6171875
step = 20000: Average Return = 200.0

Visualização

Parcelas

Use matplotlib.pyplot para mapear como a política melhorou durante o treinamento.

Uma iteração do Cartpole-v0 consiste em 200 etapas de tempo. O ambiente oferece uma recompensa de +1 para cada etapa em que o poste permanece, de modo que o retorno máximo de um episódio é 200. Os gráficos mostram o retorno aumentando para esse máximo cada vez que é avaliado durante o treinamento. (Pode ser um pouco instável e não aumentar monotonicamente a cada vez.)

 

iterations = range(0, num_iterations + 1, eval_interval)
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')
plt.ylim(top=250)
 
(0.5, 250.0)

png

Vídeos

Os gráficos são legais. Mais emocionante, porém, é ver um agente realizando uma tarefa em um ambiente.

Primeiro, crie uma função para incorporar vídeos no notebook.

 def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)
 

Agora, repita alguns episódios do jogo Cartpole com o agente. O ambiente Python subjacente (aquele "dentro" do wrapper de ambiente TensorFlow) fornece um método render() , que gera uma imagem do estado do ambiente. Estes podem ser coletados em um vídeo.

 def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):
  filename = filename + ".mp4"
  with imageio.get_writer(filename, fps=fps) as video:
    for _ in range(num_episodes):
      time_step = eval_env.reset()
      video.append_data(eval_py_env.render())
      while not time_step.is_last():
        action_step = policy.action(time_step)
        time_step = eval_env.step(action_step.action)
        video.append_data(eval_py_env.render())
  return embed_mp4(filename)




create_policy_eval_video(agent.policy, "trained-agent")
 
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.

Por diversão, compare o agente treinado (acima) com um agente que se move aleatoriamente. (Não funciona tão bem.)

 create_policy_eval_video(random_policy, "random-agent")
 
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.