このページは Cloud Translation API によって翻訳されました。
Switch to English

TFエージェントを使用してDeepQネットワークをトレーニングする

TensorFlow.orgで表示 GoogleColabで実行 GitHubでソースを表示ノートブックをダウンロード

前書き

この例は、TF-Agentsライブラリを使用して、カートポール環境でDQN(Deep Q Networks)エージェントをトレーニングする方法を示しています。

カートポール環境

トレーニング、評価、データ収集のための強化学習(RL)パイプラインのすべてのコンポーネントについて説明します。

このコードをライブで実行するには、上の[GoogleColabで実行]リンクをクリックしてください。

セットアップ

次の依存関係をインストールしていない場合は、次を実行します。

sudo apt-get install -y xvfb ffmpeg
pip install -q 'gym==0.10.11'
pip install -q 'imageio==2.4.0'
pip install -q PILLOW
pip install -q 'pyglet==1.3.2'
pip install -q pyvirtualdisplay
pip install -q tf-agents



ffmpeg is already the newest version (7:3.4.8-0ubuntu0.2).
xvfb is already the newest version (2:1.19.6-1ubuntu4.6).
The following packages were automatically installed and are no longer required:
  dconf-gsettings-backend dconf-service dkms freeglut3 freeglut3-dev
  glib-networking glib-networking-common glib-networking-services
  gsettings-desktop-schemas libcairo-gobject2 libcolord2 libdconf1
  libegl1-mesa libepoxy0 libglu1-mesa libglu1-mesa-dev libgtk-3-0
  libgtk-3-common libice-dev libjansson4 libjson-glib-1.0-0
  libjson-glib-1.0-common libproxy1v5 librest-0.7-0 libsm-dev
  libsoup-gnome2.4-1 libsoup2.4-1 libxi-dev libxmu-dev libxmu-headers
  libxnvctrl0 libxt-dev linux-gcp-headers-5.0.0-1026
  linux-headers-5.0.0-1026-gcp linux-image-5.0.0-1026-gcp
  linux-modules-5.0.0-1026-gcp pkg-config policykit-1-gnome python3-xkit
  screen-resolution-extra xserver-xorg-core-hwe-18.04
Use 'sudo apt autoremove' to remove them.
0 upgraded, 0 newly installed, 0 to remove and 87 not upgraded.

from __future__ import absolute_import, division, print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
tf.compat.v1.enable_v2_behavior()

# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
tf.version.VERSION
'2.3.1'

ハイパーパラメータ

num_iterations = 20000 # @param {type:"integer"}

initial_collect_steps = 100  # @param {type:"integer"} 
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_max_length = 100000  # @param {type:"integer"}

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
log_interval = 200  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

環境

強化学習(RL)では、環境は解決すべきタスクまたは問題を表します。標準環境は、 tf_agents.environmentsスイートを使用してTFエージェントで作成できます。 TF-Agentsには、OpenAI Gym、Atari、DMControlなどのソースから環境をロードするためのスイートがあります。

OpenAIGymスイートからCartPole環境をロードします。

env_name = 'CartPole-v0'
env = suite_gym.load(env_name)

この環境をレンダリングして、どのように見えるかを確認できます。カートにはフリースイングポールが付いています。目標は、ポールを上向きに保つためにカートを右または左に動かすことです。


env.reset()
PIL.Image.fromarray(env.render())

png

environment.stepメソッドは、環境内でactionし、環境の次の観測とアクションの報酬を含むTimeStepタプルを返します。

time_step_spec()メソッドは、 TimeStepタプルの仕様を返します。そのobservation属性は、 observationの形状、データ型、および許可される値の範囲を示します。 reward属性には、 rewardの同じ詳細が表示されます。

print('Observation Spec:')
print(env.time_step_spec().observation)
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])

print('Reward Spec:')
print(env.time_step_spec().reward)
Reward Spec:
ArraySpec(shape=(), dtype=dtype('float32'), name='reward')

action_spec()メソッドは、有効なアクションの形状、データ型、および許可された値を返します。

print('Action Spec:')
print(env.action_spec())
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

カートポール環境の場合:

  • observationは4つのフロートの配列です。
    • カートの位置と速度
    • ポールの角度位置と速度
  • rewardはスカラー浮動小数点値です
  • actionは、2つの可能な値のみを持つスカラー整数です。
    • 0 —「左に移動」
    • 1 —「右に移動」
time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step:
TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.02260124,  0.01524773,  0.01023087, -0.00098117], dtype=float32))
Next time step:
TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.0229062 ,  0.21022147,  0.01021125, -0.29041865], dtype=float32))

通常、2つの環境がインスタンス化されます。1つはトレーニング用、もう1つは評価用です。

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

Cartpole環境は、ほとんどの環境と同様に、純粋なPythonで記述されています。これは、 TFPyEnvironmentラッパーを使用してTFPyEnvironment変換されます。

元の環境のAPIはNumpy配列を使用します。 TFPyEnvironmentこれらをTensorsに変換して、 Tensorsエージェントおよびポリシーと互換性を持たせます。

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

エージェント

RL問題を解決するために使用されるアルゴリズムは、 Agentによって表されます。 TF-剤の種々の標準的な実装を提供Agents含みます。

DQNエージェントは、個別のアクションスペースがある任意の環境で使用できます。

DQNエージェントの中心となるのはQNetworkです。これは、環境からの観察を前提として、すべてのアクションのQValues (期待収益)を予測することを学習できるニューラルネットワークモデルです。

tf_agents.networks.q_networkを使用してQNetworkを作成し、 observation_specaction_spec 、およびモデルの非表示レイヤーの数とサイズを説明するタプルをaction_specます。

fc_layer_params = (100,)

q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

次に、 tf_agents.agents.dqn.dqn_agentを使用してtf_agents.agents.dqn.dqn_agentをインスタンス化しDqnAgenttime_step_specaction_spec 、およびQNetworkに加えて、エージェントコンストラクターには、オプティマイザー(この場合はAdamOptimizer )、損失関数、および整数ステップカウンターも必要です。

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

agent.initialize()

ポリシー

ポリシーは、エージェントが環境内で動作する方法を定義します。通常、強化学習の目標は、ポリシーが目的の結果を生成するまで、基礎となるモデルをトレーニングすることです。

このチュートリアルでは:

  • 望ましい結果は、カート上でポールのバランスを直立に保つことです。
  • ポリシーは、 time_step観測ごとにアクション(左または右)を返します。

エージェントには2つのポリシーが含まれています。

  • agent.policy —評価と展開に使用されるメインポリシー。
  • agent.collect_policy —データ収集に使用される2番目のポリシー。
eval_policy = agent.policy
collect_policy = agent.collect_policy

ポリシーは、エージェントとは独立して作成できます。たとえば、 tf_agents.policies.random_tf_policyを使用して、 time_stepごとにアクションをランダムに選択するポリシーを作成します。

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

ポリシーからアクションを取得するには、 policy.action(time_step)メソッドを呼び出します。 time_stepには、環境からの観測が含まれます。このメソッドは、 PolicyStep 3つのコンポーネントを持つ名前付きタプルであるPolicyStep返します。

  • action —実行するアクション(この場合、 0または1
  • state —ステートフル(つまり、RNNベース)ポリシーに使用されます
  • info —アクションのログ確率などの補助データ
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()
random_policy.action(time_step)
PolicyStep(action=<tf.Tensor: shape=(1,), dtype=int64, numpy=array([1])>, state=(), info=())

指標と評価

ポリシーの評価に使用される最も一般的な指標は、平均収益です。リターンは、エピソードの環境でポリシーを実行している間に取得した報酬の合計です。いくつかのエピソードが実行され、平均的なリターンが得られます。

次の関数は、ポリシー、環境、およびエピソードの数を指定して、ポリシーの平均リターンを計算します。


def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


# See also the metrics module for standard implementations of different metrics.
# https://github.com/tensorflow/agents/tree/master/tf_agents/metrics

random_policyこの計算を実行すると、環境のベースラインパフォーマンスが示されます。

compute_avg_return(eval_env, random_policy, num_eval_episodes)
21.6

リプレイバッファ

再生バッファは、環境から収集されたデータを追跡します。このチュートリアルでは、最も一般的なtf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer使用します。

コンストラクターには、収集するデータの仕様が必要です。これは、 collect_data_specメソッドを使用してエージェントから入手できます。バッチサイズと最大バッファー長も必要です。

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_max_length)

ほとんどのエージェントの場合、 collect_data_specTrajectoryと呼ばれる名前付きタプルであり、監視、アクション、報酬、およびその他のアイテムの仕様が含まれています。

agent.collect_data_spec
Trajectory(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), observation=BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32)), action=BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0), maximum=array(1)), policy_info=(), next_step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)))
agent.collect_data_spec._fields
('step_type',
 'observation',
 'action',
 'policy_info',
 'next_step_type',
 'reward',
 'discount')

データ収集

次に、環境内でランダムポリシーを数ステップ実行し、データを再生バッファーに記録します。


def collect_step(environment, policy, buffer):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer
  buffer.add_batch(traj)

def collect_data(env, policy, buffer, steps):
  for _ in range(steps):
    collect_step(env, policy, buffer)

collect_data(train_env, random_policy, replay_buffer, initial_collect_steps)

# This loop is so common in RL, that we provide standard implementations. 
# For more details see the drivers module.
# https://www.tensorflow.org/agents/api_docs/python/tf_agents/drivers

リプレイバッファは、トラジェクトリのコレクションになりました。

# For the curious:
# Uncomment to peel one of these off and inspect it.
# iter(replay_buffer.as_dataset()).next()

エージェントはリプレイバッファにアクセスする必要があります。これは、データをエージェントにフィードする反復可能なtf.data.Datasetパイプラインを作成することによって提供されます。

再生バッファの各行には、単一の観測ステップのみが格納されます。ただし、DQNエージェントは損失を計算するために現在と次の両方の観測を必要とするため、データセットパイプラインは、バッチ内の各アイテムについて2つの隣接する行をサンプリングします( num_steps=2 )。

このデータセットは、並列呼び出しを実行してデータをプリフェッチすることによっても最適化されます。

# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, 
    sample_batch_size=batch_size, 
    num_steps=2).prefetch(3)


dataset
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/autograph/operators/control_flow.py:1004: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

<PrefetchDataset shapes: (Trajectory(step_type=(64, 2), observation=(64, 2, 4), action=(64, 2), policy_info=(), next_step_type=(64, 2), reward=(64, 2), discount=(64, 2)), BufferInfo(ids=(64, 2), probabilities=(64,))), types: (Trajectory(step_type=tf.int32, observation=tf.float32, action=tf.int64, policy_info=(), next_step_type=tf.int32, reward=tf.float32, discount=tf.float32), BufferInfo(ids=tf.int64, probabilities=tf.float32))>
iterator = iter(dataset)

print(iterator)

<tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x7fed4470afd0>

# For the curious:
# Uncomment to see what the dataset iterator is feeding to the agent.
# Compare this representation of replay data 
# to the collection of individual trajectories shown earlier.

# iterator.next()

エージェントのトレーニング

トレーニングループ中に2つのことが発生する必要があります。

  • 環境からデータを収集する
  • そのデータを使用して、エージェントのニューラルネットワークをトレーニングします

この例では、ポリシーを定期的に評価し、現在のスコアを出力します。

以下の実行には約5分かかります。


try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few steps using collect_policy and save to the replay buffer.
  collect_data(train_env, agent.collect_policy, replay_buffer, collect_steps_per_iteration)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience).loss

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
step = 200: loss = 23.971179962158203
step = 400: loss = 16.704683303833008
step = 600: loss = 4.47831916809082
step = 800: loss = 22.284278869628906
step = 1000: loss = 9.407212257385254
step = 1000: Average Return = 22.399999618530273
step = 1200: loss = 50.882041931152344
step = 1400: loss = 64.36306762695312
step = 1600: loss = 81.06721496582031
step = 1800: loss = 14.682304382324219
step = 2000: loss = 23.132694244384766
step = 2000: Average Return = 37.20000076293945
step = 2200: loss = 11.398908615112305
step = 2400: loss = 15.055550575256348
step = 2600: loss = 14.741910934448242
step = 2800: loss = 15.771312713623047
step = 3000: loss = 17.033199310302734
step = 3000: Average Return = 49.099998474121094
step = 3200: loss = 47.628074645996094
step = 3400: loss = 28.386903762817383
step = 3600: loss = 12.820018768310547
step = 3800: loss = 8.513819694519043
step = 4000: loss = 41.961456298828125
step = 4000: Average Return = 122.69999694824219
step = 4200: loss = 21.830732345581055
step = 4400: loss = 65.9542007446289
step = 4600: loss = 50.89095687866211
step = 4800: loss = 29.666650772094727
step = 5000: loss = 10.545876502990723
step = 5000: Average Return = 114.30000305175781
step = 5200: loss = 10.088138580322266
step = 5400: loss = 22.00992202758789
step = 5600: loss = 72.86345672607422
step = 5800: loss = 20.130895614624023
step = 6000: loss = 7.348139762878418
step = 6000: Average Return = 133.39999389648438
step = 6200: loss = 82.72433471679688
step = 6400: loss = 25.444007873535156
step = 6600: loss = 139.33970642089844
step = 6800: loss = 16.38616180419922
step = 7000: loss = 8.287995338439941
step = 7000: Average Return = 156.5
step = 7200: loss = 5.057644367218018
step = 7400: loss = 50.89534378051758
step = 7600: loss = 27.254642486572266
step = 7800: loss = 52.435184478759766
step = 8000: loss = 6.562076568603516
step = 8000: Average Return = 121.5
step = 8200: loss = 10.63643741607666
step = 8400: loss = 50.55724334716797
step = 8600: loss = 68.72991180419922
step = 8800: loss = 199.4698028564453
step = 9000: loss = 297.74847412109375
step = 9000: Average Return = 197.89999389648438
step = 9200: loss = 22.64231300354004
step = 9400: loss = 11.30867862701416
step = 9600: loss = 13.854089736938477
step = 9800: loss = 80.9999008178711
step = 10000: loss = 385.2828369140625
step = 10000: Average Return = 198.0
step = 10200: loss = 403.8046569824219
step = 10400: loss = 83.03425598144531
step = 10600: loss = 160.1787567138672
step = 10800: loss = 319.04412841796875
step = 11000: loss = 100.59796142578125
step = 11000: Average Return = 200.0
step = 11200: loss = 974.8321533203125
step = 11400: loss = 298.2663269042969
step = 11600: loss = 16.23052978515625
step = 11800: loss = 26.330753326416016
step = 12000: loss = 538.9752807617188
step = 12000: Average Return = 200.0
step = 12200: loss = 32.422210693359375
step = 12400: loss = 230.4241180419922
step = 12600: loss = 327.36578369140625
step = 12800: loss = 221.0552978515625
step = 13000: loss = 36.00126647949219
step = 13000: Average Return = 200.0
step = 13200: loss = 224.20574951171875
step = 13400: loss = 252.47915649414062
step = 13600: loss = 655.738037109375
step = 13800: loss = 662.0720825195312
step = 14000: loss = 56.43755340576172
step = 14000: Average Return = 200.0
step = 14200: loss = 1747.84521484375
step = 14400: loss = 51.76509094238281
step = 14600: loss = 51.12700653076172
step = 14800: loss = 1120.146728515625
step = 15000: loss = 993.9840698242188
step = 15000: Average Return = 200.0
step = 15200: loss = 54.091712951660156
step = 15400: loss = 1946.548095703125
step = 15600: loss = 74.04914093017578
step = 15800: loss = 4848.8544921875
step = 16000: loss = 2509.5
step = 16000: Average Return = 200.0
step = 16200: loss = 280.95819091796875
step = 16400: loss = 74.34515380859375
step = 16600: loss = 71.5146255493164
step = 16800: loss = 65.942138671875
step = 17000: loss = 100.37063598632812
step = 17000: Average Return = 200.0
step = 17200: loss = 83.93265533447266
step = 17400: loss = 832.0823974609375
step = 17600: loss = 901.281982421875
step = 17800: loss = 3035.594970703125
step = 18000: loss = 1165.320068359375
step = 18000: Average Return = 200.0
step = 18200: loss = 87.855712890625
step = 18400: loss = 75.82250213623047
step = 18600: loss = 1363.5899658203125
step = 18800: loss = 68.51466369628906
step = 19000: loss = 101.25462341308594
step = 19000: Average Return = 200.0
step = 19200: loss = 773.2719116210938
step = 19400: loss = 117.0003890991211
step = 19600: loss = 3607.708740234375
step = 19800: loss = 4284.08203125
step = 20000: loss = 151.35198974609375
step = 20000: Average Return = 200.0

視覚化

プロット

matplotlib.pyplotを使用して、トレーニング中にポリシーがどのように改善されたかをグラフ化します。

Cartpole-v0 1回の反復は、200のタイムステップで構成されます。この環境では、ポールが上がったままのステップごとに+1報酬が与えられるため、1つのエピソードの最大リターンは200です。グラフは、トレーニング中に評価されるたびに、その最大に向かってリターンが増加することを示しています。 (少し不安定で、毎回単調に増加しない場合があります。)



iterations = range(0, num_iterations + 1, eval_interval)
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')
plt.ylim(top=250)
(1.0250000000000004, 250.0)

png

ビデオ

チャートはいいです。しかし、もっとエキサイティングなのは、エージェントが実際に環境内でタスクを実行しているのを見ることです。

まず、ノートブックにビデオを埋め込む関数を作成します。

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

次に、エージェントと一緒にカートポールゲームのいくつかのエピソードを繰り返します。基盤となるPython環境(TensorFlow環境ラッパーの「内部」にある環境)は、環境状態の画像を出力するrender()メソッドを提供しrender() 。これらはビデオに集めることができます。

def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):
  filename = filename + ".mp4"
  with imageio.get_writer(filename, fps=fps) as video:
    for _ in range(num_episodes):
      time_step = eval_env.reset()
      video.append_data(eval_py_env.render())
      while not time_step.is_last():
        action_step = policy.action(time_step)
        time_step = eval_env.step(action_step.action)
        video.append_data(eval_py_env.render())
  return embed_mp4(filename)




create_policy_eval_video(agent.policy, "trained-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.

楽しみのために、訓練されたエージェント(上記)をランダムに動くエージェントと比較してください。 (それもしません。)

create_policy_eval_video(random_policy, "random-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.