Đào tạo Mạng Q sâu với TF-Agents

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Giới thiệu

Ví dụ này cho thấy cách để đào tạo một DQN (Sâu Q Networks) đại lý đối với môi trường Cartpole sử dụng thư viện TF-Đại lý.

Môi trường Cartpole

Nó sẽ hướng dẫn bạn qua tất cả các thành phần trong một đường dẫn Học tập củng cố (RL) để đào tạo, đánh giá và thu thập dữ liệu.

Để chạy mã này trực tiếp, hãy nhấp vào liên kết 'Chạy trong Google Colab' ở trên.

Cài đặt

Nếu bạn chưa cài đặt các phần phụ thuộc sau, hãy chạy:

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg freeglut3-dev
pip install 'imageio==2.4.0'
pip install pyvirtualdisplay
pip install tf-agents[reverb]
pip install pyglet
from __future__ import absolute_import, division, print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay
import reverb

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import sequential
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.trajectories import trajectory
from tf_agents.specs import tensor_spec
from tf_agents.utils import common
# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
tf.version.VERSION
'2.6.0'

Siêu tham số

num_iterations = 20000 # @param {type:"integer"}

initial_collect_steps = 100  # @param {type:"integer"}
collect_steps_per_iteration =   1# @param {type:"integer"}
replay_buffer_max_length = 100000  # @param {type:"integer"}

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
log_interval = 200  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

Môi trường

Trong Học tập củng cố (RL), một môi trường đại diện cho nhiệm vụ hoặc vấn đề cần giải quyết. Môi trường tiêu chuẩn có thể được tạo ra trong TF-Đại lý sử dụng tf_agents.environments dãy phòng. TF-Agents có các dãy phòng dành cho môi trường tải từ các nguồn như OpenAI Gym, Atari và DM Control.

Tải môi trường CartPole từ bộ OpenAI Gym.

env_name = 'CartPole-v0'
env = suite_gym.load(env_name)

Bạn có thể kết xuất môi trường này để xem nó trông như thế nào. Một cây sào đung đưa tự do được gắn vào một xe đẩy. Mục đích là di chuyển xe sang phải hoặc trái để giữ cho cột hướng lên trên.

env.reset()
PIL.Image.fromarray(env.render())

png

Các environment.step phương pháp có một action trong môi trường và trả về một TimeStep tuple chứa các quan sát tiếp theo của môi trường và phần thưởng cho hành động.

Các time_step_spec() phương thức trả về thông số kỹ thuật cho TimeStep tuple. Nó observation thuộc tính hiển thị hình dạng của các quan sát, các kiểu dữ liệu, và phạm vi của các giá trị được phép. Các reward thuộc tính cho thấy các chi tiết tương tự cho các phần thưởng.

print('Observation Spec:')
print(env.time_step_spec().observation)
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])
print('Reward Spec:')
print(env.time_step_spec().reward)
Reward Spec:
ArraySpec(shape=(), dtype=dtype('float32'), name='reward')

Các action_spec() phương thức trả về hình dạng, kiểu dữ liệu, và các giá trị cho phép của các hành động hợp lệ.

print('Action Spec:')
print(env.action_spec())
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

Trong môi trường Cartpole:

  • observation là một mảng của 4 nổi:
    • vị trí và vận tốc của xe đẩy
    • vị trí góc và vận tốc của cực
  • reward là một giá trị vô hướng phao
  • action là một số nguyên vô hướng với chỉ có hai giá trị có thể:
    • 0 - "di chuyển trái"
    • 1 - "bước đi đúng đắn"
time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step:
TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([-0.02109759, -0.00062286,  0.04167245, -0.03825747], dtype=float32),
 'reward': array(0., dtype=float32),
 'step_type': array(0, dtype=int32)})
Next time step:
TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([-0.02111005,  0.1938775 ,  0.0409073 , -0.31750655], dtype=float32),
 'reward': array(1., dtype=float32),
 'step_type': array(1, dtype=int32)})

Thông thường hai môi trường được khởi tạo: một để đào tạo và một để đánh giá.

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

Môi trường Cartpole, giống như hầu hết các môi trường, được viết bằng Python thuần túy. Này được chuyển đổi sang sử dụng TensorFlow TFPyEnvironment wrapper.

API của môi trường gốc sử dụng mảng Numpy. Các TFPyEnvironment cải này để Tensors để làm cho nó tương thích với các đại lý và chính sách Tensorflow.

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Đại lý

Các thuật toán sử dụng để giải quyết một vấn đề RL được đại diện bởi một Agent . TF-Đại lý cung cấp triển khai tiêu chuẩn của một loạt các Agents , bao gồm:

Tác nhân DQN có thể được sử dụng trong bất kỳ môi trường nào có không gian hoạt động riêng biệt.

Tại trung tâm của một Agent DQN là một QNetwork , một mô hình mạng thần kinh mà có thể học để dự đoán QValues (lợi nhuận dự kiến) cho tất cả các hành động, đưa ra một quan sát từ môi trường.

Chúng tôi sẽ sử dụng tf_agents.networks. để tạo ra một QNetwork . Mạng lưới này sẽ bao gồm một chuỗi các tf.keras.layers.Dense lớp, nơi mà các lớp cuối cùng sẽ có 1 đầu ra cho mỗi hành động càng tốt.

fc_layer_params = (100, 50)
action_tensor_spec = tensor_spec.from_spec(env.action_spec())
num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

# Define a helper function to create Dense layers configured with the right
# activation and kernel initializer.
def dense_layer(num_units):
  return tf.keras.layers.Dense(
      num_units,
      activation=tf.keras.activations.relu,
      kernel_initializer=tf.keras.initializers.VarianceScaling(
          scale=2.0, mode='fan_in', distribution='truncated_normal'))

# QNetwork consists of a sequence of Dense layers followed by a dense layer
# with `num_actions` units to generate one q_value per available action as
# its output.
dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
q_values_layer = tf.keras.layers.Dense(
    num_actions,
    activation=None,
    kernel_initializer=tf.keras.initializers.RandomUniform(
        minval=-0.03, maxval=0.03),
    bias_initializer=tf.keras.initializers.Constant(-0.2))
q_net = sequential.Sequential(dense_layers + [q_values_layer])

Bây giờ sử dụng tf_agents.agents.dqn.dqn_agent để nhanh chóng một DqnAgent . Ngoài các time_step_spec , action_spec và QNetwork, các nhà xây dựng đại lý cũng đòi hỏi một ưu (trong trường hợp này, AdamOptimizer ), một chức năng mất mát, và một bước nguyên quầy.

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

agent.initialize()

Chính sách

Chính sách xác định cách tác nhân hành động trong môi trường. Thông thường, mục tiêu của học tập củng cố là đào tạo mô hình cơ bản cho đến khi chính sách tạo ra kết quả mong muốn.

Trong hướng dẫn này:

  • Kết quả mong muốn là giữ cho cột thăng bằng thẳng đứng trên xe đẩy.
  • Chính sách này sẽ trả về một hành động (trái hoặc phải) cho mỗi time_step quan sát.

Đại lý có hai chính sách:

  • agent.policy - Chính sách chính được sử dụng để đánh giá và triển khai.
  • agent.collect_policy - Một chính sách thứ hai được sử dụng để thu thập dữ liệu.
eval_policy = agent.policy
collect_policy = agent.collect_policy

Các chính sách có thể được tạo độc lập với các đại lý. Ví dụ, sử dụng tf_agents.policies.random_tf_policy tạo một chính sách mà sẽ lựa chọn ngẫu nhiên một hành động cho mỗi time_step .

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

Để có được một hành động từ một chính sách, gọi policy.action(time_step) phương pháp. Các time_step chứa các quan sát từ môi trường. Phương pháp này trả về một PolicyStep , mà là một tuple tên với ba thành phần:

  • action - hành động được thực hiện (trong trường hợp này, 0 hoặc 1 )
  • state - sử dụng cho stateful (có nghĩa là, RNN-based) chính sách
  • info - dữ liệu phụ trợ, chẳng hạn như xác suất log của các hành động
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()
random_policy.action(time_step)
PolicyStep(action=<tf.Tensor: shape=(1,), dtype=int64, numpy=array([1])>, state=(), info=())

Số liệu và Đánh giá

Số liệu phổ biến nhất được sử dụng để đánh giá một chính sách là lợi tức trung bình. Lợi tức là tổng phần thưởng nhận được khi chạy chính sách trong môi trường cho một tập. Một số tập được chạy, tạo ra lợi nhuận trung bình.

Hàm sau đây tính toán lợi nhuận trung bình của một chính sách, dựa trên chính sách, môi trường và một số tập.

def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


# See also the metrics module for standard implementations of different metrics.
# https://github.com/tensorflow/agents/tree/master/tf_agents/metrics

Chạy tính toán này trên random_policy cho thấy một hiệu suất cơ bản trong môi trường.

compute_avg_return(eval_env, random_policy, num_eval_episodes)
20.7

Replay Buffer

Để theo dõi các dữ liệu thu thập từ môi trường, chúng tôi sẽ sử dụng Reverb , một hệ thống phát lại hiệu quả, mở rộng, và dễ dàng sử dụng bởi Deepmind. Nó lưu trữ dữ liệu kinh nghiệm khi chúng tôi thu thập quỹ đạo và được tiêu thụ trong quá trình đào tạo.

Bộ đệm phát lại này được xây dựng bằng cách sử dụng các thông số kỹ thuật mô tả các tensors sẽ được lưu trữ, có thể được lấy từ tác nhân bằng cách sử dụng agent.collect_data_spec.

table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(
      agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(
    replay_buffer_signature)

table = reverb.Table(
    table_name,
    max_size=replay_buffer_max_length,
    sampler=reverb.selectors.Uniform(),
    remover=reverb.selectors.Fifo(),
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=replay_buffer_signature)

reverb_server = reverb.Server([table])

replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    agent.collect_data_spec,
    table_name=table_name,
    sequence_length=2,
    local_server=reverb_server)

rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
  replay_buffer.py_client,
  table_name,
  sequence_length=2)
[reverb/cc/platform/tfrecord_checkpointer.cc:150]  Initializing TFRecordCheckpointer in /tmp/tmpcz7e0i7c.
[reverb/cc/platform/tfrecord_checkpointer.cc:385] Loading latest checkpoint from /tmp/tmpcz7e0i7c
[reverb/cc/platform/default/server.cc:71] Started replay server on port 21909

Đối với hầu hết các đại lý, collect_data_spec là một tuple tên gọi là Trajectory , có chứa các thông số kỹ thuật cho các quan sát, hành động, phần thưởng, và các mặt hàng khác.

agent.collect_data_spec
Trajectory(
{'action': BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0), maximum=array(1)),
 'discount': BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)),
 'next_step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type'),
 'observation': BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32)),
 'policy_info': (),
 'reward': TensorSpec(shape=(), dtype=tf.float32, name='reward'),
 'step_type': TensorSpec(shape=(), dtype=tf.int32, name='step_type')})
agent.collect_data_spec._fields
('step_type',
 'observation',
 'action',
 'policy_info',
 'next_step_type',
 'reward',
 'discount')

Thu thập dữ liệu

Bây giờ thực thi chính sách ngẫu nhiên trong môi trường trong một vài bước, ghi lại dữ liệu trong bộ đệm phát lại.

Ở đây chúng tôi đang sử dụng 'PyDriver' để chạy vòng lặp thu thập kinh nghiệm. Bạn có thể tìm hiểu thêm về lái xe TF Đại lý tại chúng tôi lái xe hướng dẫn .

py_driver.PyDriver(
    env,
    py_tf_eager_policy.PyTFEagerPolicy(
      random_policy, use_tf_function=True),
    [rb_observer],
    max_steps=initial_collect_steps).run(train_py_env.reset())
(TimeStep(
 {'discount': array(1., dtype=float32),
  'observation': array([ 0.04100575,  0.16847703, -0.12718087, -0.6300714 ], dtype=float32),
  'reward': array(1., dtype=float32),
  'step_type': array(1, dtype=int32)}),
 ())

Bộ đệm phát lại bây giờ là một tập hợp các Quỹ đạo.

# For the curious:
# Uncomment to peel one of these off and inspect it.
# iter(replay_buffer.as_dataset()).next()

Tác nhân cần quyền truy cập vào bộ đệm phát lại. Này được cung cấp bằng cách tạo ra một iterable tf.data.Dataset đường ống này sẽ cung cấp dữ liệu cho các đại lý.

Mỗi hàng của bộ đệm phát lại chỉ lưu một bước quan sát duy nhất. Nhưng kể từ khi Agent DQN cần cả hai quan sát hiện nay và tiếp theo để tính toán sự mất mát, các đường ống dẫn dữ liệu sẽ lấy mẫu hai hàng liền kề cho mỗi mục trong lô ( num_steps=2 ).

Tập dữ liệu này cũng được tối ưu hóa bằng cách chạy các cuộc gọi song song và tìm nạp trước dữ liệu.

# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=batch_size,
    num_steps=2).prefetch(3)

dataset
<PrefetchDataset shapes: (Trajectory(
{action: (64, 2),
 discount: (64, 2),
 next_step_type: (64, 2),
 observation: (64, 2, 4),
 policy_info: (),
 reward: (64, 2),
 step_type: (64, 2)}), SampleInfo(key=(64, 2), probability=(64, 2), table_size=(64, 2), priority=(64, 2))), types: (Trajectory(
{action: tf.int64,
 discount: tf.float32,
 next_step_type: tf.int32,
 observation: tf.float32,
 policy_info: (),
 reward: tf.float32,
 step_type: tf.int32}), SampleInfo(key=tf.uint64, probability=tf.float64, table_size=tf.int64, priority=tf.float64))>
iterator = iter(dataset)
print(iterator)
<tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x7f3cec38cd90>
# For the curious:
# Uncomment to see what the dataset iterator is feeding to the agent.
# Compare this representation of replay data 
# to the collection of individual trajectories shown earlier.

# iterator.next()

Đào tạo đại lý

Hai điều phải xảy ra trong vòng lặp đào tạo:

  • thu thập dữ liệu từ môi trường
  • sử dụng dữ liệu đó để đào tạo (các) mạng thần kinh của tác nhân

Ví dụ này cũng đánh giá định kỳ chính sách và in ra điểm số hiện tại.

Phần sau sẽ mất ~ 5 phút để chạy.

try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step.
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

# Reset the environment.
time_step = train_py_env.reset()

# Create a driver to collect experience.
collect_driver = py_driver.PyDriver(
    env,
    py_tf_eager_policy.PyTFEagerPolicy(
      agent.collect_policy, use_tf_function=True),
    [rb_observer],
    max_steps=collect_steps_per_iteration)

for _ in range(num_iterations):

  # Collect a few steps and save to the replay buffer.
  time_step, _ = collect_driver.run(time_step)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience).loss

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:206: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (15446) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (15446) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (15446) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (15446) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (15446) so Table uniform_table is accessed directly without gRPC.
[reverb/cc/client.cc:163] Sampler and server are owned by the same process (15446) so Table uniform_table is accessed directly without gRPC.
step = 200: loss = 27.080341339111328
step = 400: loss = 3.0314550399780273
step = 600: loss = 470.9187927246094
step = 800: loss = 548.7870483398438
step = 1000: loss = 4315.17578125
step = 1000: Average Return = 48.400001525878906
step = 1200: loss = 5297.24853515625
step = 1400: loss = 11601.296875
step = 1600: loss = 60482.578125
step = 1800: loss = 802764.8125
step = 2000: loss = 1689283.0
step = 2000: Average Return = 63.400001525878906
step = 2200: loss = 4928921.0
step = 2400: loss = 5508345.0
step = 2600: loss = 17888162.0
step = 2800: loss = 23993148.0
step = 3000: loss = 10192765.0
step = 3000: Average Return = 74.0999984741211
step = 3200: loss = 88318176.0
step = 3400: loss = 77485728.0
step = 3600: loss = 3236693504.0
step = 3800: loss = 102289840.0
step = 4000: loss = 168594496.0
step = 4000: Average Return = 73.5999984741211
step = 4200: loss = 348990528.0
step = 4400: loss = 101819664.0
step = 4600: loss = 136486208.0
step = 4800: loss = 133454864.0
step = 5000: loss = 592934784.0
step = 5000: Average Return = 71.5999984741211
step = 5200: loss = 216909120.0
step = 5400: loss = 181369648.0
step = 5600: loss = 600455680.0
step = 5800: loss = 551183744.0
step = 6000: loss = 368749824.0
step = 6000: Average Return = 83.5
step = 6200: loss = 1010418176.0
step = 6400: loss = 171257856.0
step = 6600: loss = 115424904.0
step = 6800: loss = 144941152.0
step = 7000: loss = 257932752.0
step = 7000: Average Return = 107.0
step = 7200: loss = 854109248.0
step = 7400: loss = 95970128.0
step = 7600: loss = 325583744.0
step = 7800: loss = 858134016.0
step = 8000: loss = 197960128.0
step = 8000: Average Return = 124.19999694824219
step = 8200: loss = 310187552.0
step = 8400: loss = 572293760.0
step = 8600: loss = 2338323456.0
step = 8800: loss = 384659392.0
step = 9000: loss = 676924544.0
step = 9000: Average Return = 200.0
step = 9200: loss = 946199168.0
step = 9400: loss = 605189504.0
step = 9600: loss = 768988928.0
step = 9800: loss = 508231776.0
step = 10000: loss = 526518016.0
step = 10000: Average Return = 200.0
step = 10200: loss = 1461528704.0
step = 10400: loss = 709822016.0
step = 10600: loss = 2770553344.0
step = 10800: loss = 496421504.0
step = 11000: loss = 1822116864.0
step = 11000: Average Return = 200.0
step = 11200: loss = 744854208.0
step = 11400: loss = 778800384.0
step = 11600: loss = 667049216.0
step = 11800: loss = 586587648.0
step = 12000: loss = 2586833920.0
step = 12000: Average Return = 200.0
step = 12200: loss = 1002041472.0
step = 12400: loss = 1526919552.0
step = 12600: loss = 1670877056.0
step = 12800: loss = 1857608704.0
step = 13000: loss = 1040727936.0
step = 13000: Average Return = 200.0
step = 13200: loss = 1807798656.0
step = 13400: loss = 1457996544.0
step = 13600: loss = 1322671616.0
step = 13800: loss = 22940983296.0
step = 14000: loss = 1556422912.0
step = 14000: Average Return = 200.0
step = 14200: loss = 2488473600.0
step = 14400: loss = 46558289920.0
step = 14600: loss = 1958968960.0
step = 14800: loss = 4677744640.0
step = 15000: loss = 1648418304.0
step = 15000: Average Return = 200.0
step = 15200: loss = 46132723712.0
step = 15400: loss = 2189093888.0
step = 15600: loss = 1204941056.0
step = 15800: loss = 1578462080.0
step = 16000: loss = 1695949312.0
step = 16000: Average Return = 200.0
step = 16200: loss = 19554553856.0
step = 16400: loss = 2857277184.0
step = 16600: loss = 5782225408.0
step = 16800: loss = 2294467072.0
step = 17000: loss = 2397877248.0
step = 17000: Average Return = 200.0
step = 17200: loss = 2910329088.0
step = 17400: loss = 6317301760.0
step = 17600: loss = 2733602048.0
step = 17800: loss = 32502740992.0
step = 18000: loss = 6295858688.0
step = 18000: Average Return = 200.0
step = 18200: loss = 2564860160.0
step = 18400: loss = 76450430976.0
step = 18600: loss = 6347636736.0
step = 18800: loss = 6258629632.0
step = 19000: loss = 8091572224.0
step = 19000: Average Return = 200.0
step = 19200: loss = 3860335616.0
step = 19400: loss = 3552561152.0
step = 19600: loss = 4175943424.0
step = 19800: loss = 5975838720.0
step = 20000: loss = 4709884928.0
step = 20000: Average Return = 200.0

Hình dung

Lô đất

Sử dụng matplotlib.pyplot vào bảng xếp hạng như thế nào chính sách được cải thiện trong thời gian đào tạo.

Một lần lặp của Cartpole-v0 gồm 200 bước thời gian. Môi trường cung cấp cho một phần thưởng của +1 cho mỗi bước ở lại cực lên, vì vậy sự trở lại tối đa cho một tập phim là 200. Các biểu đồ hiển thị sự trở lại ngày càng tăng đối với tối đa mà mỗi lần nó được đánh giá trong đào tạo. (Nó có thể hơi không ổn định và không tăng đơn điệu mỗi lần.)

iterations = range(0, num_iterations + 1, eval_interval)
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')
plt.ylim(top=250)
(40.82000160217285, 250.0)

png

Video

Biểu đồ là tốt đẹp. Nhưng thú vị hơn là thấy một đặc vụ thực sự thực hiện một nhiệm vụ trong một môi trường.

Đầu tiên, hãy tạo một chức năng để nhúng video vào sổ ghi chép.

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

Bây giờ hãy lặp lại một vài tập của trò chơi Cartpole với đặc vụ. Môi trường Python cơ bản (một trong những "bên trong" wrapper môi trường TensorFlow) cung cấp một render() phương pháp, mà kết quả đầu ra một hình ảnh của tình trạng môi trường. Chúng có thể được thu thập thành video.

def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):
  filename = filename + ".mp4"
  with imageio.get_writer(filename, fps=fps) as video:
    for _ in range(num_episodes):
      time_step = eval_env.reset()
      video.append_data(eval_py_env.render())
      while not time_step.is_last():
        action_step = policy.action(time_step)
        time_step = eval_env.step(action_step.action)
        video.append_data(eval_py_env.render())
  return embed_mp4(filename)

create_policy_eval_video(agent.policy, "trained-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.
[swscaler @ 0x55d99fdf83c0] Warning: data is not aligned! This can lead to a speed loss

Để giải trí, hãy so sánh nhân viên được đào tạo (ở trên) với một nhân viên di chuyển ngẫu nhiên. (Nó cũng không hoạt động.)

create_policy_eval_video(random_policy, "random-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.
[swscaler @ 0x55ffa7fe73c0] Warning: data is not aligned! This can lead to a speed loss