סוכן REINFORCE

הצג באתר TensorFlow.org הפעל בגוגל קולאב צפה במקור ב-GitHub הורד מחברת

מבוא

מופעי דוגמא זו איך לאמן ולחזק סוכן על סביבת Cartpole באמצעות ספריית TF-הסוכנים, בדומה הדרכת DQN .

סביבת עגלה

אנו נדריך אותך דרך כל המרכיבים בצנרת למידת חיזוק (RL) להדרכה, הערכה ואיסוף נתונים.

להכין

אם לא התקנת את התלויות הבאות, הרץ:

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg freeglut3-dev
pip install 'imageio==2.4.0'
pip install pyvirtualdisplay
pip install tf-agents[reverb]
pip install pyglet xvfbwrapper
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import IPython
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay
import reverb

import tensorflow as tf

from tf_agents.agents.reinforce import reinforce_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import py_tf_eager_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()

היפרפרמטרים

env_name = "CartPole-v0" # @param {type:"string"}
num_iterations = 250 # @param {type:"integer"}
collect_episodes_per_iteration = 2 # @param {type:"integer"}
replay_buffer_capacity = 2000 # @param {type:"integer"}

fc_layer_params = (100,)

learning_rate = 1e-3 # @param {type:"number"}
log_interval = 25 # @param {type:"integer"}
num_eval_episodes = 10 # @param {type:"integer"}
eval_interval = 50 # @param {type:"integer"}

סביבה

סביבות ב-RL מייצגות את המשימה או הבעיה שאנו מנסים לפתור. ניתן סביבות תקן נוצרו בקלות-סוכני TF באמצעות suites . יש לנו שונות suites לטעינת סביבות ממקורות כגון כושר OpenAI, אטארי, בקרת DM, וכו ', נתן שם בסביבת מחרוזת.

כעת תנו לנו לטעון את סביבת CartPole מחבילת ה-OpenAI Gym.

env = suite_gym.load(env_name)

אנחנו יכולים לעבד את הסביבה הזו כדי לראות איך היא נראית. מוט שמתנדנד חופשי מחובר לעגלה. המטרה היא להזיז את העגלה ימינה או שמאלה על מנת לשמור על המוט כלפי מעלה.

env.reset()
PIL.Image.fromarray(env.render())

png

time_step = environment.step(action) בהצהרה לוקח action בסביבה. TimeStep tuple חזר מכיל את התצפית הבאה של הסביבה גמול בגין פעולה זו. time_step_spec() ו action_spec() שיטות בסביבה להחזיר את המפרט (סוגים, צורות, גבולות) של time_step ו action בהתאמה.

print('Observation Spec:')
print(env.time_step_spec().observation)
print('Action Spec:')
print(env.action_spec())
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

אז, אנו רואים שהתצפית היא מערך של 4 מצופים: המיקום והמהירות של העגלה, והמיקום והמהירות הזוויתית של המוט. מאז רק שתי פעולות אפשריות (לזוז ימינה או להזיז מימין), action_spec הוא סקלר כאשר 0 פירושו "לזוז ימינה" ו 1 אמצעי "צעד נכון."

time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step:
TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([ 0.02284177, -0.04785635,  0.04171623,  0.04942273], dtype=float32),
 'reward': array(0., dtype=float32),
 'step_type': array(0, dtype=int32)})
Next time step:
TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([ 0.02188464,  0.14664337,  0.04270469, -0.22981201], dtype=float32),
 'reward': array(1., dtype=float32),
 'step_type': array(1, dtype=int32)})

בדרך כלל אנו יוצרים שתי סביבות: אחת להדרכה ואחת להערכה. רוב הסביבות כתוב פייתון טהור, אבל הם יכולים להיות מומרים בקלות TensorFlow באמצעות TFPyEnvironment המעטפת. ה- API של הסביבה המקורית משתמש במערכי numpy, את TFPyEnvironment ממירה אלה מ / אל Tensors לך יותר בקלות אינטראקציה עם מדיניות TensorFlow וסוכנים.

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

סוֹכֵן

האלגוריתם שבו אנו משתמשים כדי לפתור בעיה RL מיוצג Agent . בנוסף הסוכן לחזק, TF-הסוכנים מספקים יישום סטנדרטי של מגוון Agents כגון DQN , DDPG , TD3 , PPO ו SAC .

כדי ליצור ולחזק סוכן, אנחנו קודם צריכים Actor Network שיכול ללמוד לחזות את הפעולה נתון תצפית מהסביבה.

אנחנו יכולים ליצור בקלות Actor Network באמצעות המפרט של התצפיות ופעולות. אנחנו יכולים לציין את שכבות הרשת אשר, בדוגמה זו, הוא fc_layer_params סט ויכוח על tuple של ints המייצג את הגדלים של כל שכבה נסתרת (ראה סעיף Hyperparameters לעיל).

actor_net = actor_distribution_network.ActorDistributionNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

אנחנו גם צריכים optimizer להכשיר את הרשת שיצרנו כרגע, וכן train_step_counter משתנה כדי לעקוב אחר כמה פעמים את הרשת עודכנה.

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

tf_agent = reinforce_agent.ReinforceAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=actor_net,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter)
tf_agent.initialize()

מדיניות

In-סוכני TF, מדיניות לייצג את הרעיון הסטנדרטי של מדיניות ב RL: נתון time_step לייצר פעולה או הפצה מעל פעולות. השיטה העיקרית היא policy_step = policy.action(time_step) שבו policy_step הוא tuple בשם PolicyStep(action, state, info) . policy_step.action היא action שיש להחיל על הסביבה, state מייצגת את המדינה עבור מצבים (RNN) מדיניות info עשוי להכיל מידע עזר כגון הסתברויות יומן של הפעולות.

סוכנים מכילים שני מדיניות: המדיניות העיקרית המשמשת להערכה/פריסה (agent.policy) ומדיניות נוספת המשמשת לאיסוף נתונים (agent.collect_policy).

eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy

מדדים והערכה

המדד הנפוץ ביותר המשמש להערכת פוליסה הוא התשואה הממוצעת. ההחזר הוא סכום התגמולים שהושגו בעת הפעלת פוליסה בסביבה עבור פרק, ובדרך כלל אנו ממוצעים זאת על פני מספר פרקים. אנו יכולים לחשב את מדד התשואה הממוצעת באופן הבא.

def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


# Please also see the metrics module for standard implementations of different
# metrics.

Replay Buffer

על מנת לעקוב אחר הנתונים שנאספו מהסביבה, נשתמש Reverb , מערכת שידור חוזר יעיל, להרחבה, וקל לשימוש על ידי Deepmind. הוא מאחסן נתוני ניסיון כאשר אנו אוספים מסלולים והוא נצרך במהלך האימון.

חיץ חוזר זה נבנה באמצעות מפרט המתאר את tensors כי הם יאוחסנו, אשר ניתן לקבל הסוכן באמצעות tf_agent.collect_data_spec .

table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(
      tf_agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(
      replay_buffer_signature)
table = reverb.Table(
    table_name,
    max_size=replay_buffer_capacity,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=replay_buffer_signature)

reverb_server = reverb.Server([table])

replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    table_name=table_name,
    sequence_length=None,
    local_server=reverb_server)

rb_observer = reverb_utils.ReverbAddEpisodeObserver(
    replay_buffer.py_client,
    table_name,
    replay_buffer_capacity
)
[reverb/cc/platform/tfrecord_checkpointer.cc:150]  Initializing TFRecordCheckpointer in /tmp/tmpem6la471.
[reverb/cc/platform/tfrecord_checkpointer.cc:385] Loading latest checkpoint from /tmp/tmpem6la471
[reverb/cc/platform/default/server.cc:71] Started replay server on port 19822

עבור רוב הסוכנים, collect_data_spec הוא Trajectory בשם tuple המכיל את התצפית, פעולה, לתגמל וכו '

איסוף נתונים

כפי ש-REINFORCE לומד מפרקים שלמים, אנו מגדירים פונקציה לאיסוף פרק באמצעות מדיניות איסוף הנתונים הנתונה ולשמור את הנתונים (תצפיות, פעולות, תגמולים וכו') כמסלולים במאגר החוזר. כאן אנו משתמשים ב-'PyDriver' כדי להפעיל את לולאת איסוף החוויה. אתה יכול ללמוד עוד על נהג TF סוכנים שלנו ההדרכה לנהגים .

def collect_episode(environment, policy, num_episodes):

  driver = py_driver.PyDriver(
    environment,
    py_tf_eager_policy.PyTFEagerPolicy(
      policy, use_tf_function=True),
    [rb_observer],
    max_episodes=num_episodes)
  initial_time_step = environment.reset()
  driver.run(initial_time_step)

הכשרת הסוכן

לולאת ההדרכה כוללת גם איסוף נתונים מהסביבה וגם ייעול רשתות הסוכן. לאורך הדרך, מדי פעם נעריך את מדיניות הסוכן כדי לראות מה שלומנו.

להלן ייקח ~3 דקות לרוץ.

try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
tf_agent.train = common.function(tf_agent.train)

# Reset the train step
tf_agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few episodes using collect_policy and save to the replay buffer.
  collect_episode(
      train_py_env, tf_agent.collect_policy, collect_episodes_per_iteration)

  # Use data from the buffer and update the agent's network.
  iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
  trajectories, _ = next(iterator)
  train_loss = tf_agent.train(experience=trajectories)  

  replay_buffer.clear()

  step = tf_agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss.loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
step = 25: loss = 0.8549901247024536
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
step = 50: loss = 1.0025296211242676
step = 50: Average Return = 23.200000762939453
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
step = 75: loss = 1.1377763748168945
step = 100: loss = 1.318871021270752
step = 100: Average Return = 159.89999389648438
step = 125: loss = 1.5053682327270508
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
step = 150: loss = 0.8051948547363281
step = 150: Average Return = 184.89999389648438
step = 175: loss = 0.6872963905334473
step = 200: loss = 2.7238712310791016
step = 200: Average Return = 186.8000030517578
step = 225: loss = 0.7495002746582031
step = 250: loss = -0.3333401679992676
step = 250: Average Return = 200.0

רְאִיָה

עלילות

אנחנו יכולים לתכנן תשואה לעומת צעדים גלובליים כדי לראות את הביצועים של הסוכן שלנו. בשינה Cartpole-v0 , הסביבה נותנת גמול של 1 לכול צעד זמן נשאר המוט למעלה, ומאז המספר המרבי של צעדים הוא 200, את ההחזר המקסימאלי האפשרי הוא גם 200.

steps = range(0, num_iterations + 1, eval_interval)
plt.plot(steps, returns)
plt.ylabel('Average Return')
plt.xlabel('Step')
plt.ylim(top=250)
(-0.2349997997283939, 250.0)

png

סרטונים

זה מועיל לדמיין את הביצועים של סוכן על ידי עיבוד הסביבה בכל שלב. לפני שנעשה זאת, תחילה ניצור פונקציה להטמעת סרטונים בקולאב הזה.

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

הקוד הבא מדגים את המדיניות של הסוכן לכמה פרקים:

num_episodes = 3
video_filename = 'imageio.mp4'
with imageio.get_writer(video_filename, fps=60) as video:
  for _ in range(num_episodes):
    time_step = eval_env.reset()
    video.append_data(eval_py_env.render())
    while not time_step.is_last():
      action_step = tf_agent.policy.action(time_step)
      time_step = eval_env.step(action_step.action)
      video.append_data(eval_py_env.render())

embed_mp4(video_filename)
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.
[swscaler @ 0x5604d224f3c0] Warning: data is not aligned! This can lead to a speed loss