לשמור את התאריך! קלט / פלט של Google חוזר 18-20 במאי הירשם עכשיו
דף זה תורגם על ידי Cloud Translation API.
Switch to English

הכשיר רשת Q Deep עם סוכני TF

צפה ב- TensorFlow.org הפעל בגוגל קולאב צפה במקור ב- GitHubהורד מחברת

מבוא

דוגמה זו מראה כיצד להכשיר סוכן DQN (Deep Q Networks) בסביבת Cartpole באמצעות ספריית TF-Agents.

סביבת קרטולים

זה ילווה אותך בכל הרכיבים בצינור למידת חיזוק (RL) לצורך הכשרה, הערכה ואיסוף נתונים.

כדי להפעיל קוד זה בשידור חי, לחץ על הקישור 'הפעל ב- Google Colab' למעלה.

להכין

אם לא התקנת את התלות הבאות, הפעל:

sudo apt-get install -y xvfb ffmpeg
pip install -q 'imageio==2.4.0'
pip install -q pyvirtualdisplay
pip install -q tf-agents
from __future__ import absolute_import, division, print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import sequential
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.specs import tensor_spec
from tf_agents.utils import common
# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
tf.version.VERSION
'2.4.1'

היפרפרמטרים

num_iterations = 20000 # @param {type:"integer"}

initial_collect_steps = 100  # @param {type:"integer"} 
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_max_length = 100000  # @param {type:"integer"}

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
log_interval = 200  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

סביבה

בלמידת חיזוק (RML), סביבה מייצגת את המשימה או הבעיה שיש לפתור. ניתן ליצור סביבות סטנדרטיות ב- TF-Agents באמצעות סוויטות tf_agents.environments . ל- TF-Agents יש סוויטות לטעינת סביבות ממקורות כמו OpenAI Gym, Atari ו- DM Control.

טען את סביבת CartPole מחבילת הכושר OpenAI.

env_name = 'CartPole-v0'
env = suite_gym.load(env_name)

אתה יכול לעבד את הסביבה הזו כדי לראות איך היא נראית. עגלה המתנדנדת חופשית מחוברת לעגלה. המטרה היא להזיז את העגלה ימינה או שמאלה על מנת לשמור על המוט כלפי מעלה.

env.reset()
PIL.Image.fromarray(env.render())

png

environment.step השיטה לוקחת action בסביבה ומחזירת TimeStep tuple המכיל את התצפית הבאה של הסביבה ואת הגמול עבור הפעולה.

שיטת time_step_spec() מחזירה את המפרט TimeStep של TimeStep . תכונת observation שלה מראה את צורת התצפיות, סוגי הנתונים וטווחי הערכים המותרים. תכונת reward מציגה את אותם הפרטים עבור התגמול.

print('Observation Spec:')
print(env.time_step_spec().observation)
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])
print('Reward Spec:')
print(env.time_step_spec().reward)
Reward Spec:
ArraySpec(shape=(), dtype=dtype('float32'), name='reward')

שיטת action_spec() מחזירה את הצורה, סוגי הנתונים והערכים המותרים של פעולות חוקיות.

print('Action Spec:')
print(env.action_spec())
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

בסביבת הקרטפול:

  • observation היא מערך של 4 צפים:
    • המיקום והמהירות של העגלה
    • המיקום הזוויתי ומהירות המוט
  • reward הוא ערך צף סקלרי
  • action היא מספר שלם סקלרי עם שני ערכים אפשריים בלבד:
    • 0 - "זז שמאלה"
    • 1 - "זז ימינה"
time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step:
TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([-0.03850375, -0.01932749,  0.00336228, -0.04097027], dtype=float32))
Next time step:
TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([-0.03889029,  0.17574608,  0.00254287, -0.3325905 ], dtype=float32))

בדרך כלל מייצרות שתי סביבות: אחת להכשרה ואחת להערכה.

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

סביבת Cartpole, כמו רוב הסביבות, כתובה בפייתון טהור. זה מומר ל- TensorFlow באמצעות עטיפת TFPyEnvironment .

ה- API של הסביבה המקורית משתמש במערכים Numpy. ה- TFPyEnvironment ממיר אותם Tensors כדי להפוך אותו לתואם לסוכני ומדיניות Tensorflow.

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

סוֹכֵן

האלגוריתם המשמש לפתרון בעיית RL מיוצג על ידי Agent . TF-Agents מספקת יישומים סטנדרטיים של מגוון Agents , כולל:

ניתן להשתמש בסוכן DQN בכל סביבה שיש בה מרחב פעולה נפרד.

בלבו של סוכן DQN עומד QNetwork , מודל רשת עצבי שיכול ללמוד לחזות QValues (תשואות צפויות) לכל הפעולות, בהתחשב בתצפית מהסביבה.

נשתמש ב- tf_agents.networks. כדי ליצור QNetwork . הרשת תורכב מרצף של שכבות tf.keras.layers.Dense , כאשר לשכבה הסופית תהיה פלט אחד לכל פעולה אפשרית.

fc_layer_params = (100, 50)
action_tensor_spec = tensor_spec.from_spec(env.action_spec())
num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

# Define a helper function to create Dense layers configured with the right
# activation and kernel initializer.
def dense_layer(num_units):
  return tf.keras.layers.Dense(
      num_units,
      activation=tf.keras.activations.relu,
      kernel_initializer=tf.keras.initializers.VarianceScaling(
          scale=2.0, mode='fan_in', distribution='truncated_normal'))

# QNetwork consists of a sequence of Dense layers followed by a dense layer
# with `num_actions` units to generate one q_value per available action as
# it's output.
dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
q_values_layer = tf.keras.layers.Dense(
    num_actions,
    activation=None,
    kernel_initializer=tf.keras.initializers.RandomUniform(
        minval=-0.03, maxval=0.03),
    bias_initializer=tf.keras.initializers.Constant(-0.2))
q_net = sequential.Sequential(dense_layers + [q_values_layer])

כעת השתמש ב- tf_agents.agents.dqn.dqn_agent כדי ליצור DqnAgent . בנוסף ל- time_step_spec , action_spec ו- QNetwork, בונה הסוכן דורש גם אופטימיזציה (במקרה זה, AdamOptimizer ), פונקציית הפסד AdamOptimizer צעדים שלמים.

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

agent.initialize()

מדיניות

מדיניות מגדירה את האופן שבו סוכן פועל בסביבה. בדרך כלל, המטרה של למידת חיזוק היא להכשיר את המודל הבסיסי עד שהמדיניות תניב את התוצאה הרצויה.

במדריך זה:

  • התוצאה הרצויה היא שמירה על איזון המוט זקוף מעל העגלה.
  • המדיניות מחזירה פעולה (שמאלה או ימינה) לכל תצפית time_step .

סוכנים מכילים שתי מדיניות:

  • agent.policy - המדיניות העיקרית המשמשת להערכה ולפריסה.
  • agent.collect_policy - מדיניות שנייה המשמשת לאיסוף נתונים.
eval_policy = agent.policy
collect_policy = agent.collect_policy

ניתן ליצור מדיניות ללא תלות בסוכנים. לדוגמה, השתמש ב- tf_agents.policies.random_tf_policy כדי ליצור מדיניות שתבחר באופן אקראי פעולה לכל time_step .

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

כדי להשיג פעולה ממדיניות, התקשר policy.action(time_step) . time_step מכיל תצפית מהסביבה. שיטה זו מחזירה PolicyStep , שהוא שם tuple בעל שלושה מרכיבים:

  • action - הפעולה שיש לבצע (במקרה זה, 0 או 1 )
  • state - משמש למדיניות סטטית (כלומר מבוססת RNN)
  • info - נתוני עזר, כגון הסתברויות ביומן פעולות
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()
random_policy.action(time_step)
PolicyStep(action=<tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>, state=(), info=())

מדדים והערכה

המדד הנפוץ ביותר המשמש להערכת מדיניות הוא התשואה הממוצעת. ההחזר הוא סכום התגמולים המתקבלים בעת ניהול מדיניות בסביבה לפרק. מספר פרקים מנוהלים ויוצרים תשואה ממוצעת.

הפונקציה הבאה מחשבת את התשואה הממוצעת של מדיניות, בהתחשב במדיניות, בסביבה ובמספר פרקים.

def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


# See also the metrics module for standard implementations of different metrics.
# https://github.com/tensorflow/agents/tree/master/tf_agents/metrics

הפעלת חישוב זה random_policy מראה ביצועים בסיסיים בסביבה.

compute_avg_return(eval_env, random_policy, num_eval_episodes)
20.7

הפעל מחדש את המאגר

חיץ השידור החוזר עוקב אחר הנתונים שנאספו מהסביבה. הדרכה זו משתמשת ב- tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer , מכיוון שהוא הנפוץ ביותר.

הקונסטרוקטור דורש מפרט עבור הנתונים שהוא אוסף. זה זמין מהסוכן בשיטת collect_data_spec . גודל האצווה ואורך החיץ המרבי נדרשים גם כן.

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_max_length)

עבור מרבית הסוכנים, collect_data_spec הוא תואר בשם Trajectory המכיל את המפרט לתצפיות, פעולות, תגמולים ופריטים אחרים.

agent.collect_data_spec
Trajectory(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), observation=BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32)), action=BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0), maximum=array(1)), policy_info=(), next_step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)))
agent.collect_data_spec._fields
('step_type',
 'observation',
 'action',
 'policy_info',
 'next_step_type',
 'reward',
 'discount')

איסוף נתונים

כעת בצע את המדיניות האקראית בסביבה לכמה צעדים, ותעד את הנתונים במאגר ההפעלה החוזרת.

def collect_step(environment, policy, buffer):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer
  buffer.add_batch(traj)

def collect_data(env, policy, buffer, steps):
  for _ in range(steps):
    collect_step(env, policy, buffer)

collect_data(train_env, random_policy, replay_buffer, initial_collect_steps)

# This loop is so common in RL, that we provide standard implementations. 
# For more details see the drivers module.
# https://www.tensorflow.org/agents/api_docs/python/tf_agents/drivers

חיץ השידור החוזר הוא כעת אוסף של מסלולים.

# For the curious:
# Uncomment to peel one of these off and inspect it.
# iter(replay_buffer.as_dataset()).next()

הסוכן זקוק לגישה למאגר ההפעלה החוזרת. זה מסופק על ידי יצירת צינורtf.data.Dataset שניתןtf.data.Dataset עליו שיאכיל נתונים לסוכן.

כל שורה במאגר ההפעלה החוזרת מאחסנת רק שלב תצפית יחיד. אך מכיוון שסוכן DQN זקוק לתצפית הנוכחית ולתצפית הבאה כדי לחשב את האובדן, צינור מערך הנתונים num_steps=2 שתי שורות סמוכות לכל פריט באצווה ( num_steps=2 ).

מערך נתונים זה מותאם גם על ידי הפעלת שיחות מקבילות ואיסוף נתונים מראש.

# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, 
    sample_batch_size=batch_size, 
    num_steps=2).prefetch(3)


dataset
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/autograph/operators/control_flow.py:1218: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.
<PrefetchDataset shapes: (Trajectory(step_type=(64, 2), observation=(64, 2, 4), action=(64, 2), policy_info=(), next_step_type=(64, 2), reward=(64, 2), discount=(64, 2)), BufferInfo(ids=(64, 2), probabilities=(64,))), types: (Trajectory(step_type=tf.int32, observation=tf.float32, action=tf.int64, policy_info=(), next_step_type=tf.int32, reward=tf.float32, discount=tf.float32), BufferInfo(ids=tf.int64, probabilities=tf.float32))>
iterator = iter(dataset)
print(iterator)
<tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x7fd45c8c1d30>
# For the curious:
# Uncomment to see what the dataset iterator is feeding to the agent.
# Compare this representation of replay data 
# to the collection of individual trajectories shown earlier.

# iterator.next()

הכשרת הסוכן

שני דברים חייבים לקרות במהלך לולאת האימונים:

  • לאסוף נתונים מהסביבה
  • השתמש בנתונים האלה כדי לאמן את רשתות העצביות של הסוכן

דוגמה זו גם מעריכה מעת לעת את המדיניות ומדפיסה את הציון הנוכחי.

להמשך הפעולה ייקח ~ 5 דקות.

try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few steps using collect_policy and save to the replay buffer.
  collect_data(train_env, agent.collect_policy, replay_buffer, collect_steps_per_iteration)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience).loss

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
step = 200: loss = 7.1376953125
step = 400: loss = 10.289593696594238
step = 600: loss = 19.545475006103516
step = 800: loss = 36.43363952636719
step = 1000: loss = 5.419467449188232
step = 1000: Average Return = 54.20000076293945
step = 1200: loss = 34.57895278930664
step = 1400: loss = 46.99689483642578
step = 1600: loss = 40.32734298706055
step = 1800: loss = 47.736488342285156
step = 2000: loss = 49.739967346191406
step = 2000: Average Return = 190.0
step = 2200: loss = 844.77099609375
step = 2400: loss = 4119.0439453125
step = 2600: loss = 255.4365234375
step = 2800: loss = 143.8880157470703
step = 3000: loss = 225.29849243164062
step = 3000: Average Return = 200.0
step = 3200: loss = 599.4447021484375
step = 3400: loss = 1019.076171875
step = 3600: loss = 1293.42919921875
step = 3800: loss = 2320.72216796875
step = 4000: loss = 41184.2109375
step = 4000: Average Return = 200.0
step = 4200: loss = 4963.77880859375
step = 4400: loss = 5559.41552734375
step = 4600: loss = 5049.33203125
step = 4800: loss = 78073.1328125
step = 5000: loss = 36796.953125
step = 5000: Average Return = 200.0
step = 5200: loss = 9767.94140625
step = 5400: loss = 22039.0625
step = 5600: loss = 22485.873046875
step = 5800: loss = 388653.09375
step = 6000: loss = 216484.203125
step = 6000: Average Return = 200.0
step = 6200: loss = 25190.736328125
step = 6400: loss = 34655.4296875
step = 6600: loss = 46708.58984375
step = 6800: loss = 53557.33984375
step = 7000: loss = 32926.9296875
step = 7000: Average Return = 200.0
step = 7200: loss = 45613.109375
step = 7400: loss = 37991.77734375
step = 7600: loss = 69465.1796875
step = 7800: loss = 82083.4140625
step = 8000: loss = 45056.90234375
step = 8000: Average Return = 200.0
step = 8200: loss = 82084.65625
step = 8400: loss = 123847.4375
step = 8600: loss = 78031.5
step = 8800: loss = 67286.640625
step = 9000: loss = 227665.625
step = 9000: Average Return = 200.0
step = 9200: loss = 189473.8125
step = 9400: loss = 175909.5625
step = 9600: loss = 219817.296875
step = 9800: loss = 402358.21875
step = 10000: loss = 295174.6875
step = 10000: Average Return = 200.0
step = 10200: loss = 297499.125
step = 10400: loss = 328237.8125
step = 10600: loss = 6439973.5
step = 10800: loss = 196471.09375
step = 11000: loss = 412917.0
step = 11000: Average Return = 200.0
step = 11200: loss = 224290.40625
step = 11400: loss = 215624.25
step = 11600: loss = 428307.0625
step = 11800: loss = 979577.5625
step = 12000: loss = 192212.125
step = 12000: Average Return = 200.0
step = 12200: loss = 395367.75
step = 12400: loss = 24781564.0
step = 12600: loss = 10260239.0
step = 12800: loss = 540727.875
step = 13000: loss = 638093.25
step = 13000: Average Return = 200.0
step = 13200: loss = 774659.4375
step = 13400: loss = 725100.875
step = 13600: loss = 759670.625
step = 13800: loss = 650109.0625
step = 14000: loss = 714944.5
step = 14000: Average Return = 200.0
step = 14200: loss = 765532.1875
step = 14400: loss = 559825.3125
step = 14600: loss = 1272693.375
step = 14800: loss = 1612590.75
step = 15000: loss = 1986240.25
step = 15000: Average Return = 200.0
step = 15200: loss = 1077601.125
step = 15400: loss = 1231381.125
step = 15600: loss = 2455884.5
step = 15800: loss = 23788086.0
step = 16000: loss = 1422684.0
step = 16000: Average Return = 200.0
step = 16200: loss = 2394509.0
step = 16400: loss = 4006896.0
step = 16600: loss = 1646453.625
step = 16800: loss = 1340689.125
step = 17000: loss = 26381168.0
step = 17000: Average Return = 200.0
step = 17200: loss = 2301365.75
step = 17400: loss = 589999.5
step = 17600: loss = 1951961.5
step = 17800: loss = 717706.3125
step = 18000: loss = 1537313.875
step = 18000: Average Return = 200.0
step = 18200: loss = 2864821.0
step = 18400: loss = 1825217.5
step = 18600: loss = 2375406.75
step = 18800: loss = 4396603.0
step = 19000: loss = 2847116.0
step = 19000: Average Return = 200.0
step = 19200: loss = 2379508.5
step = 19400: loss = 785688.625
step = 19600: loss = 2396494.0
step = 19800: loss = 3612146.75
step = 20000: loss = 3842961.75
step = 20000: Average Return = 200.0

רְאִיָה

עלילות

השתמש ב- matplotlib.pyplot כדי matplotlib.pyplot כיצד המדיניות השתפרה במהלך האימונים.

איטרציה אחת של Cartpole-v0 כוללת 200 שלבי זמן. הסביבה מעניקה תגמול של +1 עבור כל שלב שהמוט נשאר למעלה, ולכן התשואה המקסימאלית לפרק אחד היא 200. בתרשימים עולה כי התשואה עולה לעבר המקסימום בכל פעם שהוא מוערך במהלך האימון. (זה יכול להיות קצת לא יציב ולא יגדל בצורה מונוטונית בכל פעם.)

iterations = range(0, num_iterations + 1, eval_interval)
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')
plt.ylim(top=250)
(1.3400002002716054, 250.0)

png

סרטונים

תרשימים הם נחמדים. אבל יותר מרגש הוא לראות סוכן שבאמת מבצע משימה בסביבה.

ראשית, צור פונקציה להטמעת סרטונים במחברת.

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

עכשיו חזרו דרך כמה פרקים של המשחק Cartpole עם הסוכן. סביבת הפיתון הבסיסית (זו "בתוך" עטיפת הסביבה TensorFlow) מספקת שיטת render() , שמפיקה תמונה של מצב הסביבה. את אלה ניתן לאסוף לסרטון.

def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):
  filename = filename + ".mp4"
  with imageio.get_writer(filename, fps=fps) as video:
    for _ in range(num_episodes):
      time_step = eval_env.reset()
      video.append_data(eval_py_env.render())
      while not time_step.is_last():
        action_step = policy.action(time_step)
        time_step = eval_env.step(action_step.action)
        video.append_data(eval_py_env.render())
  return embed_mp4(filename)




create_policy_eval_video(agent.policy, "trained-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.

בכיף, השווה את הסוכן המאומן (לעיל) לסוכן הנע באופן אקראי. (זה לא עושה גם כן.)

create_policy_eval_video(random_policy, "random-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.