Questa pagina è stata tradotta dall'API Cloud Translation.
Switch to English

Addestra una rete Deep Q con TF-Agents

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza sorgente su GitHub Scarica notebook

introduzione

Questo esempio mostra come addestrare un agente DQN (Deep Q Networks) sull'ambiente Cartpole utilizzando la libreria TF-Agents.

Ambiente cartpole

Ti guiderà attraverso tutti i componenti in una pipeline di Reinforcement Learning (RL) per la formazione, la valutazione e la raccolta dei dati.

Per eseguire questo codice in tempo reale, fai clic sul link "Esegui in Google Colab" sopra.

Impostare

Se non hai installato le seguenti dipendenze, esegui:

sudo apt-get install -y xvfb ffmpeg
pip install -q gym
pip install -q 'imageio==2.4.0'
pip install -q PILLOW
pip install -q pyglet
pip install -q pyvirtualdisplay
pip install -q tf-agents



ffmpeg is already the newest version (7:3.4.8-0ubuntu0.2).
xvfb is already the newest version (2:1.19.6-1ubuntu4.7).
The following packages were automatically installed and are no longer required:
  dconf-gsettings-backend dconf-service dkms freeglut3 freeglut3-dev
  glib-networking glib-networking-common glib-networking-services
  gsettings-desktop-schemas libcairo-gobject2 libcolord2 libdconf1
  libegl1-mesa libepoxy0 libglu1-mesa libglu1-mesa-dev libgtk-3-0
  libgtk-3-common libice-dev libjansson4 libjson-glib-1.0-0
  libjson-glib-1.0-common libproxy1v5 librest-0.7-0 libsm-dev
  libsoup-gnome2.4-1 libsoup2.4-1 libxi-dev libxmu-dev libxmu-headers
  libxnvctrl0 libxt-dev linux-gcp-headers-5.0.0-1026
  linux-headers-5.0.0-1026-gcp linux-image-5.0.0-1026-gcp
  linux-modules-5.0.0-1026-gcp pkg-config policykit-1-gnome python3-xkit
  screen-resolution-extra xserver-xorg-core-hwe-18.04
Use 'sudo apt autoremove' to remove them.
0 upgraded, 0 newly installed, 0 to remove and 91 not upgraded.

from __future__ import absolute_import, division, print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
tf.compat.v1.enable_v2_behavior()

# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
tf.version.VERSION
'2.3.1'

Iperparametri

num_iterations = 20000 # @param {type:"integer"}

initial_collect_steps = 100  # @param {type:"integer"} 
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_max_length = 100000  # @param {type:"integer"}

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
log_interval = 200  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

Ambiente

In Reinforcement Learning (RL), un ambiente rappresenta il compito o il problema da risolvere. Gli ambienti standard possono essere creati in TF-Agent utilizzando tf_agents.environments suite tf_agents.environments . TF-Agents dispone di suite per il caricamento di ambienti da sorgenti come OpenAI Gym, Atari e DM Control.

Carica l'ambiente CartPole dalla suite OpenAI Gym.

env_name = 'CartPole-v0'
env = suite_gym.load(env_name)

Puoi eseguire il rendering di questo ambiente per vedere come appare. Un palo oscillante è fissato a un carrello. L'obiettivo è spostare il carrello a destra oa sinistra in modo da mantenere il palo rivolto verso l'alto.


env.reset()
PIL.Image.fromarray(env.render())

png

Il metodo environment.step esegue action nell'ambiente e restituisce una tupla TimeStep contenente l'osservazione successiva dell'ambiente e la ricompensa per l'azione.

Il metodo time_step_spec() restituisce la specifica per la tupla TimeStep . Il suo attributo di observation mostra la forma delle osservazioni, i tipi di dati e gli intervalli di valori consentiti. L'attributo del reward mostra gli stessi dettagli per il premio.

print('Observation Spec:')
print(env.time_step_spec().observation)
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])

print('Reward Spec:')
print(env.time_step_spec().reward)
Reward Spec:
ArraySpec(shape=(), dtype=dtype('float32'), name='reward')

Il metodo action_spec() restituisce la forma, i tipi di dati e i valori consentiti delle azioni valide.

print('Action Spec:')
print(env.action_spec())
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

Nell'ambiente Cartpole:

  • observation è un array di 4 float:
    • la posizione e la velocità del carrello
    • la posizione angolare e la velocità del polo
  • reward è un valore float scalare
  • action è un numero intero scalare con solo due possibili valori:
    • 0 - "sposta a sinistra"
    • 1 - "sposta a destra"
time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step:
TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.02701063,  0.02364248, -0.03979321, -0.02606781], dtype=float32))
Next time step:
TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.02748348,  0.21931183, -0.04031456, -0.3310356 ], dtype=float32))

Di solito vengono istanziati due ambienti: uno per la formazione e uno per la valutazione.

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

L'ambiente di Cartpole, come la maggior parte degli ambienti, è scritto in puro Python. Questo viene convertito in TensorFlow utilizzando il wrapper TFPyEnvironment .

L'API dell'ambiente originale utilizza gli array Numpy. I TFPyEnvironment converte questi ai Tensors per renderlo compatibile con gli agenti e le politiche tensorflow.

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Agente

L'algoritmo utilizzato per risolvere un problema RL è rappresentato da un Agent . TF-Agents fornisce implementazioni standard di una varietà di Agents , tra cui:

L'agente DQN può essere utilizzato in qualsiasi ambiente che abbia uno spazio di azione discreto.

Il cuore di un agente DQN è un QNetwork , un modello di rete neurale che può imparare a prevedere i QValues (ritorni attesi) per tutte le azioni, data un'osservazione dall'ambiente.

Utilizzare tf_agents.networks.q_network per creare un QNetwork , passando l' observation_spec , l' action_spec e una tupla che descrive il numero e la dimensione dei livelli nascosti del modello.

fc_layer_params = (100,)

q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

Ora usa tf_agents.agents.dqn.dqn_agent per istanziare un DqnAgent . Oltre a time_step_spec , action_spec e QNetwork, il costruttore dell'agente richiede anche un ottimizzatore (in questo caso AdamOptimizer ), una funzione di perdita e un contatore di passi interi.

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

agent.initialize()

Politiche

Una politica definisce il modo in cui un agente agisce in un ambiente. In genere, l'obiettivo dell'apprendimento per rinforzo è addestrare il modello sottostante fino a quando la politica non produce il risultato desiderato.

In questo tutorial:

  • Il risultato desiderato è mantenere il palo in equilibrio in posizione verticale sul carrello.
  • La politica restituisce un'azione (sinistra o destra) per ogni osservazione time_step .

Gli agenti contengono due criteri:

  • agent.policy : il criterio principale utilizzato per la valutazione e la distribuzione.
  • agent.collect_policy : una seconda policy utilizzata per la raccolta dei dati.
eval_policy = agent.policy
collect_policy = agent.collect_policy

Le policy possono essere create indipendentemente dagli agenti. Ad esempio, usa tf_agents.policies.random_tf_policy per creare una policy che selezionerà casualmente un'azione per ogni time_step .

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

Per ottenere un'azione da una policy, chiama il policy.action(time_step) . Il time_step contiene l'osservazione dall'ambiente. Questo metodo restituisce un PolicyStep , che è una tupla denominata con tre componenti:

  • action - l'azione da intraprendere (in questo caso, 0 o 1 )
  • state : utilizzato per le politiche stateful (ovvero basate su RNN)
  • info - dati ausiliari, come le probabilità di registrazione delle azioni
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()
random_policy.action(time_step)
PolicyStep(action=<tf.Tensor: shape=(1,), dtype=int64, numpy=array([1])>, state=(), info=())

Metriche e valutazione

La metrica più comune utilizzata per valutare una politica è il rendimento medio. Il ritorno è la somma dei premi ottenuti durante l'esecuzione di una politica in un ambiente per un episodio. Vengono eseguiti diversi episodi, creando un rendimento medio.

La seguente funzione calcola il rendimento medio di una politica, data la politica, l'ambiente e un numero di episodi.


def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


# See also the metrics module for standard implementations of different metrics.
# https://github.com/tensorflow/agents/tree/master/tf_agents/metrics

L'esecuzione di questo calcolo sulla random_policy mostra una prestazione di base nell'ambiente.

compute_avg_return(eval_env, random_policy, num_eval_episodes)
22.4

Replay Buffer

Il buffer di riproduzione tiene traccia dei dati raccolti dall'ambiente. Questo tutorial utilizza tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer , poiché è il più comune.

Il costruttore richiede le specifiche per i dati che raccoglierà. Questo è disponibile dall'agente utilizzando il metodo collect_data_spec . Sono necessarie anche la dimensione del batch e la lunghezza massima del buffer.

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_max_length)

Per la maggior parte degli agenti, collect_data_spec è una tupla denominata Trajectory , contenente le specifiche per osservazioni, azioni, ricompense e altri elementi.

agent.collect_data_spec
Trajectory(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), observation=BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32)), action=BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0), maximum=array(1)), policy_info=(), next_step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)))
agent.collect_data_spec._fields
('step_type',
 'observation',
 'action',
 'policy_info',
 'next_step_type',
 'reward',
 'discount')

Raccolta dati

Ora esegui la politica casuale nell'ambiente per alcuni passaggi, registrando i dati nel buffer di riproduzione.


def collect_step(environment, policy, buffer):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer
  buffer.add_batch(traj)

def collect_data(env, policy, buffer, steps):
  for _ in range(steps):
    collect_step(env, policy, buffer)

collect_data(train_env, random_policy, replay_buffer, initial_collect_steps)

# This loop is so common in RL, that we provide standard implementations. 
# For more details see the drivers module.
# https://www.tensorflow.org/agents/api_docs/python/tf_agents/drivers

Il buffer di riproduzione è ora una raccolta di traiettorie.

# For the curious:
# Uncomment to peel one of these off and inspect it.
# iter(replay_buffer.as_dataset()).next()

L'agente deve accedere al buffer di riproduzione. Ciò viene fornito creando una pipeline iterabile tf.data.Dataset che fornirà dati all'agente.

Ogni riga del buffer di riproduzione memorizza solo un singolo passaggio di osservazione. Ma poiché l'agente DQN necessita sia dell'osservazione corrente che di quella successiva per calcolare la perdita, la pipeline del set di dati campionerà due righe adiacenti per ogni elemento nel batch ( num_steps=2 ).

Questo set di dati è inoltre ottimizzato eseguendo chiamate parallele e precaricando i dati.

# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, 
    sample_batch_size=batch_size, 
    num_steps=2).prefetch(3)


dataset
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/autograph/operators/control_flow.py:1004: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

<PrefetchDataset shapes: (Trajectory(step_type=(64, 2), observation=(64, 2, 4), action=(64, 2), policy_info=(), next_step_type=(64, 2), reward=(64, 2), discount=(64, 2)), BufferInfo(ids=(64, 2), probabilities=(64,))), types: (Trajectory(step_type=tf.int32, observation=tf.float32, action=tf.int64, policy_info=(), next_step_type=tf.int32, reward=tf.float32, discount=tf.float32), BufferInfo(ids=tf.int64, probabilities=tf.float32))>
iterator = iter(dataset)

print(iterator)

<tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x7fa7cf729908>

# For the curious:
# Uncomment to see what the dataset iterator is feeding to the agent.
# Compare this representation of replay data 
# to the collection of individual trajectories shown earlier.

# iterator.next()

Formazione dell'agente

Due cose devono accadere durante il ciclo di addestramento:

  • raccogliere dati dall'ambiente
  • utilizzare questi dati per addestrare le reti neurali dell'agente

Questo esempio valuta periodicamente anche la politica e stampa il punteggio corrente.

L'esecuzione di quanto segue richiederà circa 5 minuti.


try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few steps using collect_policy and save to the replay buffer.
  collect_data(train_env, agent.collect_policy, replay_buffer, collect_steps_per_iteration)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience).loss

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
step = 200: loss = 8.992376327514648
step = 400: loss = 37.675716400146484
step = 600: loss = 13.444989204406738
step = 800: loss = 72.30130767822266
step = 1000: loss = 25.08966636657715
step = 1000: Average Return = 73.5
step = 1200: loss = 73.62870788574219
step = 1400: loss = 110.38285827636719
step = 1600: loss = 46.68946838378906
step = 1800: loss = 26.890771865844727
step = 2000: loss = 176.09652709960938
step = 2000: Average Return = 24.700000762939453
step = 2200: loss = 26.064491271972656
step = 2400: loss = 11.5733003616333
step = 2600: loss = 22.180831909179688
step = 2800: loss = 42.521949768066406
step = 3000: loss = 65.42731475830078
step = 3000: Average Return = 39.0
step = 3200: loss = 47.9765739440918
step = 3400: loss = 50.59381866455078
step = 3600: loss = 50.47264862060547
step = 3800: loss = 8.670256614685059
step = 4000: loss = 52.830589294433594
step = 4000: Average Return = 41.599998474121094
step = 4200: loss = 10.61288833618164
step = 4400: loss = 35.88026809692383
step = 4600: loss = 6.05510950088501
step = 4800: loss = 8.100227355957031
step = 5000: loss = 83.557861328125
step = 5000: Average Return = 81.69999694824219
step = 5200: loss = 9.37910270690918
step = 5400: loss = 75.47665405273438
step = 5600: loss = 128.91830444335938
step = 5800: loss = 89.63421630859375
step = 6000: loss = 18.526662826538086
step = 6000: Average Return = 96.80000305175781
step = 6200: loss = 52.134490966796875
step = 6400: loss = 114.90756225585938
step = 6600: loss = 237.4613037109375
step = 6800: loss = 73.82711029052734
step = 7000: loss = 214.91725158691406
step = 7000: Average Return = 171.1999969482422
step = 7200: loss = 7.875909805297852
step = 7400: loss = 9.217475891113281
step = 7600: loss = 358.5450439453125
step = 7800: loss = 8.870625495910645
step = 8000: loss = 91.55878448486328
step = 8000: Average Return = 191.8000030517578
step = 8200: loss = 120.59859466552734
step = 8400: loss = 10.38186264038086
step = 8600: loss = 142.29592895507812
step = 8800: loss = 8.689146995544434
step = 9000: loss = 13.707746505737305
step = 9000: Average Return = 149.6999969482422
step = 9200: loss = 15.458602905273438
step = 9400: loss = 190.49436950683594
step = 9600: loss = 31.84662628173828
step = 9800: loss = 256.5697937011719
step = 10000: loss = 164.7451171875
step = 10000: Average Return = 188.0
step = 10200: loss = 8.817024230957031
step = 10400: loss = 22.690643310546875
step = 10600: loss = 14.418478012084961
step = 10800: loss = 253.4925994873047
step = 11000: loss = 302.92608642578125
step = 11000: Average Return = 196.89999389648438
step = 11200: loss = 273.21844482421875
step = 11400: loss = 550.127685546875
step = 11600: loss = 19.929574966430664
step = 11800: loss = 256.5919189453125
step = 12000: loss = 16.680150985717773
step = 12000: Average Return = 196.6999969482422
step = 12200: loss = 19.210708618164062
step = 12400: loss = 897.2586669921875
step = 12600: loss = 373.3843994140625
step = 12800: loss = 19.439762115478516
step = 13000: loss = 441.92645263671875
step = 13000: Average Return = 196.0
step = 13200: loss = 331.1238708496094
step = 13400: loss = 30.180822372436523
step = 13600: loss = 1115.3497314453125
step = 13800: loss = 20.65155029296875
step = 14000: loss = 26.081703186035156
step = 14000: Average Return = 198.10000610351562
step = 14200: loss = 1073.404052734375
step = 14400: loss = 574.0071411132812
step = 14600: loss = 343.24481201171875
step = 14800: loss = 112.7371597290039
step = 15000: loss = 1589.832275390625
step = 15000: Average Return = 200.0
step = 15200: loss = 32.94782257080078
step = 15400: loss = 656.8715209960938
step = 15600: loss = 513.935302734375
step = 15800: loss = 1349.8228759765625
step = 16000: loss = 44.40769577026367
step = 16000: Average Return = 198.3000030517578
step = 16200: loss = 220.28024291992188
step = 16400: loss = 29.474273681640625
step = 16600: loss = 41.93848419189453
step = 16800: loss = 41.10401153564453
step = 17000: loss = 909.60888671875
step = 17000: Average Return = 200.0
step = 17200: loss = 851.1168212890625
step = 17400: loss = 48.420257568359375
step = 17600: loss = 3194.95751953125
step = 17800: loss = 641.7561645507812
step = 18000: loss = 37.21417236328125
step = 18000: Average Return = 197.10000610351562
step = 18200: loss = 59.17278289794922
step = 18400: loss = 63.922279357910156
step = 18600: loss = 76.3153076171875
step = 18800: loss = 892.47705078125
step = 19000: loss = 1339.738525390625
step = 19000: Average Return = 199.10000610351562
step = 19200: loss = 38.821475982666016
step = 19400: loss = 62.41900634765625
step = 19600: loss = 31.353261947631836
step = 19800: loss = 29.47930145263672
step = 20000: loss = 23.328723907470703
step = 20000: Average Return = 199.6999969482422

Visualizzazione

Trame

Usa matplotlib.pyplot per tracciare come la policy è migliorata durante la formazione.

Cartpole-v0 di Cartpole-v0 consiste in 200 fasi temporali. L'ambiente dà una ricompensa di +1 per ogni passo in cui il palo rimane alzato, quindi il rendimento massimo per un episodio è 200. Il grafico mostra il ritorno che aumenta verso quel massimo ogni volta che viene valutato durante l'allenamento. (Potrebbe essere un po 'instabile e non aumentare in modo monotono ogni volta.)



iterations = range(0, num_iterations + 1, eval_interval)
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')
plt.ylim(top=250)
(-0.2349997997283939, 250.0)

png

Video

I grafici sono belli. Ma la cosa più eccitante è vedere un agente che esegue effettivamente un'attività in un ambiente.

Innanzitutto, crea una funzione per incorporare i video nel taccuino.

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

Ora ripeti alcuni episodi del gioco Cartpole con l'agente. L'ambiente Python sottostante (quello "dentro" il wrapper dell'ambiente TensorFlow) fornisce un metodo render() , che restituisce un'immagine dello stato dell'ambiente. Questi possono essere raccolti in un video.

def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):
  filename = filename + ".mp4"
  with imageio.get_writer(filename, fps=fps) as video:
    for _ in range(num_episodes):
      time_step = eval_env.reset()
      video.append_data(eval_py_env.render())
      while not time_step.is_last():
        action_step = policy.action(time_step)
        time_step = eval_env.step(action_step.action)
        video.append_data(eval_py_env.render())
  return embed_mp4(filename)




create_policy_eval_video(agent.policy, "trained-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.

Per divertimento, confronta l'agente addestrato (sopra) con un agente che si muove in modo casuale. (Non funziona altrettanto bene.)

create_policy_eval_video(random_policy, "random-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.