Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Latih Jaringan Q Dalam dengan Agen TF

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHubUnduh buku catatan

pengantar

Contoh ini menunjukkan cara melatih agen DQN (Deep Q Networks) di lingkungan Cartpole menggunakan pustaka TF-Agents.

Lingkungan Cartpole

Ini akan memandu Anda melalui semua komponen dalam pipeline Reinforcement Learning (RL) untuk pelatihan, evaluasi, dan pengumpulan data.

Untuk menjalankan kode ini secara langsung, klik tautan 'Jalankan di Google Colab' di atas.

Mendirikan

Jika Anda belum menginstal dependensi berikut, jalankan:

sudo apt-get install -y xvfb ffmpeg
pip install -q 'imageio==2.4.0'
pip install -q pyvirtualdisplay
pip install -q tf-agents
from __future__ import absolute_import, division, print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
tf.compat.v1.enable_v2_behavior()

# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()
tf.version.VERSION
'2.4.0'

Hyperparameter

num_iterations = 20000 # @param {type:"integer"}

initial_collect_steps = 100  # @param {type:"integer"} 
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_max_length = 100000  # @param {type:"integer"}

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
log_interval = 200  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

Lingkungan Hidup

Dalam Reinforcement Learning (RL), lingkungan merepresentasikan tugas atau masalah yang harus diselesaikan. Lingkungan standar dapat dibuat di TF-Agents menggunakan tf_agents.environments suites. Agen TF memiliki suite untuk memuat lingkungan dari sumber seperti OpenAI Gym, Atari, dan Kontrol DM.

Muat lingkungan CartPole dari suite OpenAI Gym.

env_name = 'CartPole-v0'
env = suite_gym.load(env_name)

Anda dapat membuat lingkungan ini untuk melihat tampilannya. Tiang berayun bebas dipasang ke gerobak. Tujuannya adalah untuk menggerakkan gerobak ke kanan atau ke kiri agar tiang tetap mengarah ke atas.

env.reset()
PIL.Image.fromarray(env.render())
.dll

png

Metode environment.step mengambil action di lingkungan dan mengembalikan tupel TimeStep berisi pengamatan lingkungan berikutnya dan hadiah untuk tindakan tersebut.

Metode time_step_spec() mengembalikan spesifikasi untuk tupel TimeStep . Atribut observation menunjukkan bentuk observasi, tipe data, dan rentang nilai yang diizinkan. Atribut reward menunjukkan detail yang sama untuk reward.

print('Observation Spec:')
print(env.time_step_spec().observation)
Observation Spec:
BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38])

print('Reward Spec:')
print(env.time_step_spec().reward)
Reward Spec:
ArraySpec(shape=(), dtype=dtype('float32'), name='reward')

Metode action_spec() mengembalikan bentuk, tipe data, dan nilai yang diizinkan dari tindakan valid.

print('Action Spec:')
print(env.action_spec())
Action Spec:
BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1)

Di lingkungan Cartpole:

  • observation adalah array 4 pelampung:
    • posisi dan kecepatan gerobak
    • posisi sudut dan kecepatan tiang
  • reward adalah nilai float skalar
  • action adalah integer skalar dengan hanya dua kemungkinan nilai:
    • 0 - "pindah ke kiri"
    • 1 - "bergerak ke kanan"
time_step = env.reset()
print('Time step:')
print(time_step)

action = np.array(1, dtype=np.int32)

next_time_step = env.step(action)
print('Next time step:')
print(next_time_step)
Time step:
TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([-0.02220688,  0.02916698, -0.03515396, -0.03343702], dtype=float32))
Next time step:
TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([-0.02162354,  0.22477496, -0.03582269, -0.33700085], dtype=float32))

Biasanya dua lingkungan dipakai: satu untuk pelatihan dan satu lagi untuk evaluasi.

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

Lingkungan Cartpole, seperti kebanyakan lingkungan, ditulis dengan Python murni. Ini diubah menjadi TensorFlow menggunakan pembungkus TFPyEnvironment .

API lingkungan asli menggunakan array Numpy. The TFPyEnvironment mualaf ini untuk Tensors untuk membuatnya kompatibel dengan agen dan kebijakan Tensorflow.

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Agen

Algoritma yang digunakan untuk menyelesaikan masalah RL diwakili oleh Agent . Agen TF menyediakan implementasi standar dari berbagai Agents , termasuk:

Agen DQN dapat digunakan di lingkungan apa pun yang memiliki ruang tindakan diskrit.

Inti dari Agen DQN adalah QNetwork , model jaringan saraf yang dapat belajar memprediksi QValues (pengembalian yang diharapkan) untuk semua tindakan, dengan pengamatan dari lingkungan.

Gunakan tf_agents.networks.q_network untuk membuat QNetwork , dengan meneruskan observation_spec , action_spec , dan tuple yang menjelaskan jumlah dan ukuran lapisan tersembunyi model.

fc_layer_params = (100,)

q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

Sekarang gunakan tf_agents.agents.dqn.dqn_agent untuk membuat DqnAgent . Selain time_step_spec , action_spec , dan QNetwork, pembuat agen juga memerlukan pengoptimal (dalam hal ini, AdamOptimizer ), fungsi kerugian, dan penghitung langkah integer.

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

agent.initialize()

Kebijakan

Kebijakan menentukan cara agen bertindak di lingkungan. Biasanya, tujuan pembelajaran penguatan adalah untuk melatih model yang mendasari sampai kebijakan menghasilkan hasil yang diinginkan.

Dalam tutorial ini:

  • Hasil yang diinginkan adalah menjaga tiang tetap tegak di atas gerobak.
  • Kebijakan mengembalikan tindakan (kiri atau kanan) untuk setiap pengamatan time_step .

Agen berisi dua kebijakan:

  • agent.policy - Kebijakan utama yang digunakan untuk evaluasi dan penyebaran.
  • agent.collect_policy - Kebijakan kedua yang digunakan untuk pengumpulan data.
eval_policy = agent.policy
collect_policy = agent.collect_policy

Kebijakan dapat dibuat secara independen dari agen. Misalnya, gunakan tf_agents.policies.random_tf_policy untuk membuat kebijakan yang akan memilih tindakan secara acak untuk setiap time_step .

random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

Untuk mendapatkan tindakan dari kebijakan, panggil metode policy.action(time_step) . time_step berisi observasi dari lingkungan. Metode ini mengembalikan PolicyStep , yang merupakan tupel bernama dengan tiga komponen:

  • action - tindakan yang akan diambil (dalam hal ini, 0 atau 1 )
  • state - digunakan untuk kebijakan stateful (yaitu, berbasis RNN)
  • info - data tambahan, seperti log kemungkinan tindakan
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()
random_policy.action(time_step)
PolicyStep(action=<tf.Tensor: shape=(1,), dtype=int64, numpy=array([1])>, state=(), info=())

Metrik dan Evaluasi

Metrik paling umum yang digunakan untuk mengevaluasi kebijakan adalah pengembalian rata-rata. Imbalannya adalah jumlah reward yang diperoleh saat menjalankan kebijakan di lingkungan untuk suatu episode. Beberapa episode dijalankan, menciptakan pengembalian rata-rata.

Fungsi berikut menghitung pengembalian rata-rata suatu kebijakan, berdasarkan kebijakan, lingkungan, dan sejumlah episode.

def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


# See also the metrics module for standard implementations of different metrics.
# https://github.com/tensorflow/agents/tree/master/tf_agents/metrics

Menjalankan komputasi ini di random_policy menunjukkan performa dasar di lingkungan.

compute_avg_return(eval_env, random_policy, num_eval_episodes)
19.3

Putar Ulang Buffer

Buffer replay melacak data yang dikumpulkan dari lingkungan. Tutorial ini menggunakan tf_agents.replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer , karena ini yang paling umum.

Konstruktor memerlukan spesifikasi untuk data yang akan dikumpulkannya. Ini tersedia dari agen yang menggunakan metode collect_data_spec . Ukuran batch dan panjang buffer maksimum juga diperlukan.

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_max_length)

Untuk sebagian besar agen, collect_data_spec adalah tupel bernama Trajectory , berisi spesifikasi observasi, tindakan, hadiah, dan item lainnya.

agent.collect_data_spec
Trajectory(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), observation=BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38],
      dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38],
      dtype=float32)), action=BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0), maximum=array(1)), policy_info=(), next_step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)))
agent.collect_data_spec._fields
('step_type',
 'observation',
 'action',
 'policy_info',
 'next_step_type',
 'reward',
 'discount')

Pengumpulan data

Sekarang jalankan kebijakan acak di lingkungan untuk beberapa langkah, merekam data di buffer pemutaran ulang.

def collect_step(environment, policy, buffer):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer
  buffer.add_batch(traj)

def collect_data(env, policy, buffer, steps):
  for _ in range(steps):
    collect_step(env, policy, buffer)

collect_data(train_env, random_policy, replay_buffer, initial_collect_steps)

# This loop is so common in RL, that we provide standard implementations. 
# For more details see the drivers module.
# https://www.tensorflow.org/agents/api_docs/python/tf_agents/drivers

Buffer replay sekarang menjadi kumpulan Trajectories.

# For the curious:
# Uncomment to peel one of these off and inspect it.
# iter(replay_buffer.as_dataset()).next()

Agen membutuhkan akses ke buffer replay. Ini disediakan dengan membuat pipelinetf.data.Dataset iterable yang akantf.data.Dataset data ke agen.

Setiap baris buffer replay hanya menyimpan satu langkah observasi. Tapi karena DQN Agent membutuhkan observasi saat ini dan berikutnya untuk menghitung kerugian, pipeline kumpulan data akan mengambil sampel dua baris yang berdekatan untuk setiap item dalam batch ( num_steps=2 ).

Dataset ini juga dioptimalkan dengan menjalankan panggilan paralel dan data prefetching.

# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, 
    sample_batch_size=batch_size, 
    num_steps=2).prefetch(3)


dataset
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/autograph/operators/control_flow.py:1218: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

<PrefetchDataset shapes: (Trajectory(step_type=(64, 2), observation=(64, 2, 4), action=(64, 2), policy_info=(), next_step_type=(64, 2), reward=(64, 2), discount=(64, 2)), BufferInfo(ids=(64, 2), probabilities=(64,))), types: (Trajectory(step_type=tf.int32, observation=tf.float32, action=tf.int64, policy_info=(), next_step_type=tf.int32, reward=tf.float32, discount=tf.float32), BufferInfo(ids=tf.int64, probabilities=tf.float32))>
iterator = iter(dataset)

print(iterator)
<tensorflow.python.data.ops.iterator_ops.OwnedIterator object at 0x7f424ef1e080>

# For the curious:
# Uncomment to see what the dataset iterator is feeding to the agent.
# Compare this representation of replay data 
# to the collection of individual trajectories shown earlier.

# iterator.next()

Melatih agen

Dua hal harus terjadi selama loop pelatihan:

  • mengumpulkan data dari lingkungan
  • gunakan data tersebut untuk melatih jaringan saraf agen

Contoh ini juga mengevaluasi kebijakan secara berkala dan mencetak skor saat ini.

Proses berikut memerlukan waktu ~ 5 menit.

try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few steps using collect_policy and save to the replay buffer.
  collect_data(train_env, agent.collect_policy, replay_buffer, collect_steps_per_iteration)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience).loss

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
step = 200: loss = 12.070486068725586
step = 400: loss = 8.28309440612793
step = 600: loss = 31.025835037231445
step = 800: loss = 13.815164566040039
step = 1000: loss = 8.706375122070312
step = 1000: Average Return = 146.10000610351562
step = 1200: loss = 5.937136650085449
step = 1400: loss = 8.289731979370117
step = 1600: loss = 33.76885223388672
step = 1800: loss = 47.953025817871094
step = 2000: loss = 52.09583282470703
step = 2000: Average Return = 27.799999237060547
step = 2200: loss = 5.141240119934082
step = 2400: loss = 21.2930908203125
step = 2600: loss = 25.088130950927734
step = 2800: loss = 56.888206481933594
step = 3000: loss = 76.79216003417969
step = 3000: Average Return = 21.700000762939453
step = 3200: loss = 26.36425018310547
step = 3400: loss = 4.469114780426025
step = 3600: loss = 34.6283073425293
step = 3800: loss = 14.598098754882812
step = 4000: loss = 17.743749618530273
step = 4000: Average Return = 46.0
step = 4200: loss = 26.917938232421875
step = 4400: loss = 56.554386138916016
step = 4600: loss = 19.315950393676758
step = 4800: loss = 16.675647735595703
step = 5000: loss = 136.85499572753906
step = 5000: Average Return = 126.4000015258789
step = 5200: loss = 35.19789123535156
step = 5400: loss = 88.04693603515625
step = 5600: loss = 70.27599334716797
step = 5800: loss = 72.77832794189453
step = 6000: loss = 57.98759078979492
step = 6000: Average Return = 127.4000015258789
step = 6200: loss = 191.23748779296875
step = 6400: loss = 6.994782447814941
step = 6600: loss = 90.5509033203125
step = 6800: loss = 38.94111251831055
step = 7000: loss = 8.685359001159668
step = 7000: Average Return = 130.5
step = 7200: loss = 128.6968536376953
step = 7400: loss = 82.48645782470703
step = 7600: loss = 44.56972122192383
step = 7800: loss = 55.02344512939453
step = 8000: loss = 141.50399780273438
step = 8000: Average Return = 159.3000030517578
step = 8200: loss = 6.795340061187744
step = 8400: loss = 13.114398956298828
step = 8600: loss = 91.94058990478516
step = 8800: loss = 82.01240539550781
step = 9000: loss = 44.955650329589844
step = 9000: Average Return = 178.8000030517578
step = 9200: loss = 68.20870208740234
step = 9400: loss = 154.3193359375
step = 9600: loss = 137.314453125
step = 9800: loss = 74.73216247558594
step = 10000: loss = 91.0711441040039
step = 10000: Average Return = 184.6999969482422
step = 10200: loss = 85.52545166015625
step = 10400: loss = 8.654191970825195
step = 10600: loss = 22.928178787231445
step = 10800: loss = 20.545730590820312
step = 11000: loss = 271.4393005371094
step = 11000: Average Return = 174.3000030517578
step = 11200: loss = 9.628021240234375
step = 11400: loss = 137.62472534179688
step = 11600: loss = 5.842207908630371
step = 11800: loss = 174.1510772705078
step = 12000: loss = 7.528541564941406
step = 12000: Average Return = 182.1999969482422
step = 12200: loss = 8.867502212524414
step = 12400: loss = 81.00005340576172
step = 12600: loss = 12.920578002929688
step = 12800: loss = 142.42262268066406
step = 13000: loss = 143.52105712890625
step = 13000: Average Return = 155.5
step = 13200: loss = 297.1153259277344
step = 13400: loss = 390.66412353515625
step = 13600: loss = 142.9129638671875
step = 13800: loss = 63.83035659790039
step = 14000: loss = 21.4559268951416
step = 14000: Average Return = 162.8000030517578
step = 14200: loss = 129.82550048828125
step = 14400: loss = 16.236011505126953
step = 14600: loss = 9.619779586791992
step = 14800: loss = 184.36654663085938
step = 15000: loss = 10.463872909545898
step = 15000: Average Return = 183.0
step = 15200: loss = 33.84937286376953
step = 15400: loss = 30.32630157470703
step = 15600: loss = 453.1896667480469
step = 15800: loss = 215.66067504882812
step = 16000: loss = 524.3294677734375
step = 16000: Average Return = 188.5
step = 16200: loss = 25.474245071411133
step = 16400: loss = 493.87481689453125
step = 16600: loss = 44.63227462768555
step = 16800: loss = 22.07838249206543
step = 17000: loss = 1202.003662109375
step = 17000: Average Return = 196.8000030517578
step = 17200: loss = 253.89215087890625
step = 17400: loss = 256.4725646972656
step = 17600: loss = 12.142976760864258
step = 17800: loss = 28.95887565612793
step = 18000: loss = 334.6312561035156
step = 18000: Average Return = 184.8000030517578
step = 18200: loss = 14.26879596710205
step = 18400: loss = 375.1312561035156
step = 18600: loss = 295.83953857421875
step = 18800: loss = 38.81288528442383
step = 19000: loss = 20.399158477783203
step = 19000: Average Return = 182.5
step = 19200: loss = 33.21414566040039
step = 19400: loss = 32.32746124267578
step = 19600: loss = 706.607177734375
step = 19800: loss = 381.3253479003906
step = 20000: loss = 349.1406555175781
step = 20000: Average Return = 178.89999389648438

Visualisasi

Plot

Gunakan matplotlib.pyplot untuk membuat bagan bagaimana kebijakan ditingkatkan selama pelatihan.

Satu iterasi Cartpole-v0 terdiri dari 200 langkah waktu. Lingkungan memberikan hadiah +1 untuk setiap langkah tiang tetap naik, jadi pengembalian maksimum untuk satu episode adalah 200. Bagan menunjukkan pengembalian meningkat ke arah maksimum itu setiap kali dievaluasi selama pelatihan. (Ini mungkin sedikit tidak stabil dan tidak meningkat secara monoton setiap saat.)

iterations = range(0, num_iterations + 1, eval_interval)
plt.plot(iterations, returns)
plt.ylabel('Average Return')
plt.xlabel('Iterations')
plt.ylim(top=250)
(4.125000047683715, 250.0)

png

Video

Grafiknya bagus. Tetapi yang lebih menarik adalah melihat agen benar-benar melakukan tugas di lingkungan.

Pertama, buat fungsi untuk menyematkan video di buku catatan.

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

Sekarang lakukan iterasi melalui beberapa episode permainan Cartpole dengan agen. Lingkungan Python yang mendasari (yang "di dalam" pembungkus lingkungan TensorFlow) menyediakan metode render() , yang mengeluarkan gambar status lingkungan. Ini bisa dikumpulkan menjadi video.

def create_policy_eval_video(policy, filename, num_episodes=5, fps=30):
  filename = filename + ".mp4"
  with imageio.get_writer(filename, fps=fps) as video:
    for _ in range(num_episodes):
      time_step = eval_env.reset()
      video.append_data(eval_py_env.render())
      while not time_step.is_last():
        action_step = policy.action(time_step)
        time_step = eval_env.step(action_step.action)
        video.append_data(eval_py_env.render())
  return embed_mp4(filename)




create_policy_eval_video(agent.policy, "trained-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.

Untuk bersenang-senang, bandingkan agen terlatih (di atas) dengan agen yang bergerak secara acak. (Itu tidak berhasil juga.)

create_policy_eval_video(random_policy, "random-agent")
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.