REINFORCE đại lý

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Giới thiệu

Ví dụ này cho thấy cách để đào tạo một củng cố đại lý đối với môi trường Cartpole sử dụng thư viện TF-Đại lý, tương tự như hướng dẫn DQN .

Môi trường Cartpole

Chúng tôi sẽ hướng dẫn bạn qua tất cả các thành phần trong đường dẫn Học tập củng cố (RL) để đào tạo, đánh giá và thu thập dữ liệu.

Thành lập

Nếu bạn chưa cài đặt các phần phụ thuộc sau, hãy chạy:

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg freeglut3-dev
pip install 'imageio==2.4.0'
pip install pyvirtualdisplay
pip install tf-agents[reverb]
pip install pyglet xvfbwrapper
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import IPython
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay
import reverb

import tensorflow as tf

from tf_agents.agents.reinforce import reinforce_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import py_tf_eager_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.specs import tensor_spec
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()

Siêu tham số

env_name = "CartPole-v0" # @param {type:"string"}
num_iterations = 250 # @param {type:"integer"}
collect_episodes_per_iteration = 2 # @param {type:"integer"}
replay_buffer_capacity = 2000 # @param {type:"integer"}

fc_layer_params = (100,)

learning_rate = 1e-3 # @param {type:"number"}
log_interval = 25 # @param {type:"integer"}
num_eval_episodes = 10 # @param {type:"integer"}
eval_interval = 50 # @param {type:"integer"}

Môi trường

Các môi trường trong RL đại diện cho nhiệm vụ hoặc vấn đề mà chúng tôi đang cố gắng giải quyết. Môi trường tiêu chuẩn có thể dễ dàng tạo ra trong TF-Đại lý sử dụng suites . Chúng tôi có khác nhau suites cho tải môi trường từ các nguồn như OpenAI phòng tập thể dục, Atari, DM kiểm soát, vv cho một tên môi trường chuỗi.

Bây giờ chúng ta hãy tải môi trường CartPole từ bộ OpenAI Gym.

env = suite_gym.load(env_name)

Chúng ta có thể kết xuất môi trường này để xem nó trông như thế nào. Một cây sào đung đưa tự do được gắn vào một xe đẩy. Mục đích là di chuyển xe sang phải hoặc trái để giữ cho cột hướng lên trên.

env.reset()
PIL.Image.fromarray(env.render())

png

Các time_step = environment.step(action) tuyên bố sẽ đưa action trong môi trường. Các TimeStep tuple trở chứa quan sát tiếp theo của môi trường và phần thưởng cho hành động đó. Các time_step_spec()action_spec() phương pháp trong môi trường trả lại thông số kỹ thuật (chủng loại, hình dạng, giới hạn) của time_stepaction tương ứng.

print('Observation Spec:')
print(env.time_step_spec().observation)
print('Action Spec:')
print(env.action_spec())
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

Vì vậy, chúng ta thấy rằng quan sát là một mảng gồm 4 vật nổi: vị trí và vận tốc của xe đẩy, vị trí góc và vận tốc của cực. Vì chỉ có hai hành động có thể xảy ra (di chuyển sang trái hoặc di chuyển bên phải), các action_spec là một vô hướng trong đó 0 có nghĩa là "di chuyển sang trái" và 1 có nghĩa là "bước đi đúng đắn."

time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step:
TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([ 0.02284177, -0.04785635,  0.04171623,  0.04942273], dtype=float32),
 'reward': array(0., dtype=float32),
 'step_type': array(0, dtype=int32)})
Next time step:
TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([ 0.02188464,  0.14664337,  0.04270469, -0.22981201], dtype=float32),
 'reward': array(1., dtype=float32),
 'step_type': array(1, dtype=int32)})

Thông thường chúng tôi tạo ra hai môi trường: một để đào tạo và một để đánh giá. Hầu hết các môi trường được viết bằng python tinh khiết, nhưng họ có thể dễ dàng chuyển đổi sang sử dụng TensorFlow TFPyEnvironment wrapper. API môi trường ban đầu của sử dụng mảng NumPy, các TFPyEnvironment chuyển đổi những đến / từ Tensors để bạn có thể dễ dàng hơn tương tác với các chính sách và các đại lý TensorFlow.

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Đại lý

Các thuật toán mà chúng tôi sử dụng để giải quyết một vấn đề RL được biểu diễn dưới dạng một Agent . Ngoài các đại lý củng cố, TF-Đại lý cung cấp triển khai tiêu chuẩn của một loạt các Agents như DQN , DDPG , TD3 , PPOSAC .

Để tạo một củng cố Agent, đầu tiên chúng ta cần một Actor Network đó có thể học để dự đoán hành động đưa ra một quan sát từ môi trường.

Chúng ta có thể dễ dàng tạo một Actor Network sử dụng các thông số kỹ thuật của các quan sát và hành động. Chúng tôi có thể xác định các lớp trong mạng đó, trong ví dụ này, là fc_layer_params bộ tranh luận với một tuple của ints hiện các kích thước của mỗi lớp ẩn (xem phần siêu tham số trên).

actor_net = actor_distribution_network.ActorDistributionNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

Chúng ta cũng cần một optimizer để đào tạo mạng, chúng tôi vừa tạo ra, và một train_step_counter biến để theo dõi bao nhiêu lần mạng đã được cập nhật.

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

tf_agent = reinforce_agent.ReinforceAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=actor_net,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter)
tf_agent.initialize()

Chính sách

Trong TF-Đại lý, chính sách đại diện cho quan điểm tiêu chuẩn của chính sách trong RL: cho một time_step tạo ra một hành động hoặc một bản phân phối qua hành động. Phương pháp chính là policy_step = policy.action(time_step) nơi policy_step là một tên tuple PolicyStep(action, state, info) . Các policy_step.actionaction được áp dụng đối với môi trường, state đại diện cho nhà nước cho stateful chính sách và (RNN) info có thể chứa thông tin phụ trợ như xác suất log của hành động.

Đại lý chứa hai chính sách: chính sách chính được sử dụng để đánh giá / triển khai (agent.policy) và một chính sách khác được sử dụng để thu thập dữ liệu (agent.collect_policy).

eval_policy = tf_agent.policy
collect_policy = tf_agent.collect_policy

Số liệu và Đánh giá

Số liệu phổ biến nhất được sử dụng để đánh giá một chính sách là lợi tức trung bình. Lợi nhuận là tổng số phần thưởng nhận được khi chạy chính sách trong môi trường cho một tập và chúng tôi thường tính trung bình con số này qua một vài tập. Chúng tôi có thể tính toán số liệu lợi nhuận trung bình như sau.

def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


# Please also see the metrics module for standard implementations of different
# metrics.

Replay Buffer

Để theo dõi các dữ liệu thu thập từ môi trường, chúng tôi sẽ sử dụng Reverb , một hệ thống phát lại hiệu quả, mở rộng, và dễ dàng sử dụng bởi Deepmind. Nó lưu trữ dữ liệu kinh nghiệm khi chúng tôi thu thập quỹ đạo và được tiêu thụ trong quá trình đào tạo.

Đệm phát lại này được xây dựng bằng kỹ thuật mô tả tensors mà phải được lưu trữ, có thể được lấy từ các đại lý sử dụng tf_agent.collect_data_spec .

table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(
      tf_agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(
      replay_buffer_signature)
table = reverb.Table(
    table_name,
    max_size=replay_buffer_capacity,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=replay_buffer_signature)

reverb_server = reverb.Server([table])

replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    table_name=table_name,
    sequence_length=None,
    local_server=reverb_server)

rb_observer = reverb_utils.ReverbAddEpisodeObserver(
    replay_buffer.py_client,
    table_name,
    replay_buffer_capacity
)
[reverb/cc/platform/tfrecord_checkpointer.cc:150]  Initializing TFRecordCheckpointer in /tmp/tmpem6la471.
[reverb/cc/platform/tfrecord_checkpointer.cc:385] Loading latest checkpoint from /tmp/tmpem6la471
[reverb/cc/platform/default/server.cc:71] Started replay server on port 19822

Đối với hầu hết các đại lý, các collect_data_spec là một Trajectory tên là tuple chứa các quan sát, hành động, khen thưởng, vv

Thu thập dữ liệu

Khi REINFORCE học từ toàn bộ các tập, chúng tôi xác định một hàm để thu thập một tập bằng cách sử dụng chính sách thu thập dữ liệu đã cho và lưu dữ liệu (quan sát, hành động, phần thưởng, v.v.) dưới dạng quỹ đạo trong bộ đệm phát lại. Ở đây chúng tôi đang sử dụng 'PyDriver' để chạy vòng lặp thu thập kinh nghiệm. Bạn có thể tìm hiểu thêm về lái xe TF Đại lý tại chúng tôi lái xe hướng dẫn .

def collect_episode(environment, policy, num_episodes):

  driver = py_driver.PyDriver(
    environment,
    py_tf_eager_policy.PyTFEagerPolicy(
      policy, use_tf_function=True),
    [rb_observer],
    max_episodes=num_episodes)
  initial_time_step = environment.reset()
  driver.run(initial_time_step)

Đào tạo đại lý

Vòng huấn luyện bao gồm cả việc thu thập dữ liệu từ môi trường và tối ưu hóa mạng của tác nhân. Trên đường đi, chúng tôi sẽ thỉnh thoảng đánh giá chính sách của đại lý để xem chúng tôi đang hoạt động như thế nào.

Phần sau sẽ mất ~ 3 phút để chạy.

try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
tf_agent.train = common.function(tf_agent.train)

# Reset the train step
tf_agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few episodes using collect_policy and save to the replay buffer.
  collect_episode(
      train_py_env, tf_agent.collect_policy, collect_episodes_per_iteration)

  # Use data from the buffer and update the agent's network.
  iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
  trajectories, _ = next(iterator)
  train_loss = tf_agent.train(experience=trajectories)  

  replay_buffer.clear()

  step = tf_agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss.loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, tf_agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
step = 25: loss = 0.8549901247024536
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
step = 50: loss = 1.0025296211242676
step = 50: Average Return = 23.200000762939453
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
step = 75: loss = 1.1377763748168945
step = 100: loss = 1.318871021270752
step = 100: Average Return = 159.89999389648438
step = 125: loss = 1.5053682327270508
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (20164) so Table uniform_table is accessed directly without gRPC.
step = 150: loss = 0.8051948547363281
step = 150: Average Return = 184.89999389648438
step = 175: loss = 0.6872963905334473
step = 200: loss = 2.7238712310791016
step = 200: Average Return = 186.8000030517578
step = 225: loss = 0.7495002746582031
step = 250: loss = -0.3333401679992676
step = 250: Average Return = 200.0

Hình dung

Lô đất

Chúng tôi có thể vẽ biểu đồ trở lại so với các bước toàn cầu để xem hiệu suất của đại lý của chúng tôi. Trong Cartpole-v0 , môi trường cung cấp cho một phần thưởng của 1 cho mỗi bước thời gian ở lại cực lên, và vì số lượng tối đa các bước là 200, có thể tối đa lợi nhuận cũng là 200.

steps = range(0, num_iterations + 1, eval_interval)
plt.plot(steps, returns)
plt.ylabel('Average Return')
plt.xlabel('Step')
plt.ylim(top=250)
(-0.2349997997283939, 250.0)

png

Video

Sẽ rất hữu ích nếu bạn hình dung hiệu suất của một tác nhân bằng cách hiển thị môi trường ở mỗi bước. Trước khi làm điều đó, trước tiên chúng ta hãy tạo một chức năng để nhúng video vào chuyên mục này.

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

Đoạn mã sau hiển thị chính sách của đại lý trong một vài tập:

num_episodes = 3
video_filename = 'imageio.mp4'
with imageio.get_writer(video_filename, fps=60) as video:
  for _ in range(num_episodes):
    time_step = eval_env.reset()
    video.append_data(eval_py_env.render())
    while not time_step.is_last():
      action_step = tf_agent.policy.action(time_step)
      time_step = eval_env.step(action_step.action)
      video.append_data(eval_py_env.render())

embed_mp4(video_filename)
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.
[swscaler @ 0x5604d224f3c0] Warning: data is not aligned! This can lead to a speed loss