Tarihi kaydet! Google I / O 18-20 Mayıs'ta geri dönüyor Şimdi kaydolun
Bu sayfa, Cloud Translation API ile çevrilmiştir.
Switch to English

DQN C51 / Gökkuşağı

TensorFlow.org'da görüntüleyin Google Colab'da çalıştırın Kaynağı GitHub'da görüntüleyinDefteri indirin

Giriş

Bu örnek, TF-Aracılar kitaplığını kullanarak Kartpole ortamında bir Kategorik DQN (C51) aracısının nasıl eğitileceğini gösterir.

Kart direği ortamı

Ön koşul olarak DQN eğitimine göz attığınızdan emin olun. Bu öğretici, DQN öğreticisine aşina olduğunuzu varsayacaktır; esas olarak DQN ve C51 arasındaki farklara odaklanacaktır.

Kurulum

Henüz tf-agent'ları kurmadıysanız, şunu çalıştırın:

sudo apt-get install -y xvfb ffmpeg
pip install -q 'imageio==2.4.0'
pip install -q pyvirtualdisplay
pip install -q tf-agents
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import PIL.Image
import pyvirtualdisplay

import tensorflow as tf

from tf_agents.agents.categorical_dqn import categorical_dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import categorical_q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

tf.compat.v1.enable_v2_behavior()


# Set up a virtual display for rendering OpenAI gym environments.
display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()

Hiperparametreler

env_name = "CartPole-v1" # @param {type:"string"}
num_iterations = 15000 # @param {type:"integer"}

initial_collect_steps = 1000  # @param {type:"integer"} 
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_capacity = 100000  # @param {type:"integer"}

fc_layer_params = (100,)

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
gamma = 0.99
log_interval = 200  # @param {type:"integer"}

num_atoms = 51  # @param {type:"integer"}
min_q_value = -20  # @param {type:"integer"}
max_q_value = 20  # @param {type:"integer"}
n_step_update = 2  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

Çevre

Biri eğitim ve diğeri değerlendirme için olmak üzere ortamı eskisi gibi yükleyin. Burada, maksimum ödülü 200 yerine 500 olan CartPole-v1'i (DQN eğitimindeki CartPole-v0'a kıyasla) kullanıyoruz.

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Ajan

C51, DQN'ye dayalı bir Q-öğrenme algoritmasıdır. DQN gibi, ayrı bir eylem alanı olan herhangi bir ortamda kullanılabilir.

C51 ve DQN arasındaki temel fark, her bir durum-eylem çifti için Q değerini tahmin etmek yerine, C51, Q değerinin olasılık dağılımı için bir histogram modeli öngörmesidir:

Örnek C51 Dağılımı

Algoritma, basitçe beklenen değerden ziyade dağıtımı öğrenerek, eğitim sırasında daha kararlı kalarak daha iyi nihai performans sağlayabilir. Bu, özellikle tek bir ortalamanın doğru bir resim sağlamadığı iki modlu veya hatta çok modlu değer dağılımlarının olduğu durumlarda geçerlidir.

Değerler yerine olasılık dağılımları üzerinde eğitim almak için, C51'in kayıp fonksiyonunu hesaplamak için bazı karmaşık dağılım hesaplamaları yapması gerekir. Ancak endişelenmeyin, tüm bunlar sizin için TF-Agent'larda halledilir!

Bir C51 Aracısı oluşturmak için önce bir CategoricalQNetwork oluşturmamız gerekir. API CategoricalQNetwork ile aynı olan QNetwork , bir argüman olduğunu hariç num_atoms . Bu, olasılık dağılımı tahminlerimizdeki destek noktalarının sayısını temsil eder. (Yukarıdaki görüntü, her biri dikey bir mavi çubukla temsil edilen 10 destek noktası içerir.) Adından da anlayabileceğiniz gibi, varsayılan atom sayısı 51'dir.

categorical_q_net = categorical_q_network.CategoricalQNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    num_atoms=num_atoms,
    fc_layer_params=fc_layer_params)

Ayrıca az önce oluşturduğumuz ağı eğitmek için bir optimizer train_step_counter ve ağın kaç kez güncellendiğini takip etmek için bir train_step_counter değişkenine ihtiyacımız var.

Vanilya DqnAgent bir diğer önemli farkın, artık min_q_value ve max_q_value bağımsız değişken olarak belirtmemiz min_q_value . Bunlar, desteğin en uç değerlerini (başka bir deyişle, her iki taraftaki 51 atomun en uç noktası) belirtir. Bunları kendi ortamınız için uygun şekilde seçtiğinizden emin olun. Burada -20 ve 20 kullanıyoruz.

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.compat.v2.Variable(0)

agent = categorical_dqn_agent.CategoricalDqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    categorical_q_network=categorical_q_net,
    optimizer=optimizer,
    min_q_value=min_q_value,
    max_q_value=max_q_value,
    n_step_update=n_step_update,
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=gamma,
    train_step_counter=train_step_counter)
agent.initialize()

Unutulmaması gereken son bir nokta da $ n $ = 2 ile n adımlı güncellemeleri kullanmak için bir argüman eklememizdir. Tek adımlı Q öğrenmede ($ n $ = 1), yalnızca Q değerleri arasındaki hatayı hesaplıyoruz geçerli zaman adımında ve tek adımlı geri dönüşü kullanarak bir sonraki adım adımında (Bellman optimallik denklemine göre). Tek adımlı dönüş şu şekilde tanımlanır:

$ G_t = R_ {t + 1} + \ gamma V (s_ {t + 1}) $

burada $ V (s) = \ max_a {Q (s, a)} $ 'ı tanımlarız.

N adımlı güncellemeler, standart tek adımlı dönüş işlevinin $ n $ kez genişletilmesini içerir:

$ G_t ^ n = R_ {t + 1} + \ gamma R_ {t + 2} + \ gamma ^ 2 R_ {t + 3} + \ dots + \ gamma ^ n V (s_ {t + n}) $

N adımlı güncellemeler, aracının gelecekte daha ileride önyükleme yapmasını sağlar ve n $ 'lık doğru değerle bu genellikle daha hızlı öğrenmeye yol açar.

C51 ve n adımlı güncellemeler, Rainbow aracısının çekirdeğini oluşturmak için genellikle öncelikli tekrarla birleştirilse de, öncelikli yeniden oynatmanın uygulanmasında ölçülebilir bir gelişme görmedik. Dahası, C51 aracımızı tek başına n adımlı güncellemelerle birleştirdiğimizde, temsilcimizin test ettiğimiz Atari ortamları örneğinde diğer Rainbow aracıları kadar iyi performans gösterdiğini gördük.

Ölçümler ve Değerlendirme

Bir politikayı değerlendirmek için kullanılan en yaygın ölçüm ortalama getiridir. Geri dönüş, bir bölüm için bir ortamda bir politika yürütürken elde edilen ödüllerin toplamıdır ve genellikle bunun birkaç bölüm boyunca ortalamasını alırız. Ortalama getiri metriğini aşağıdaki gibi hesaplayabiliriz.

def compute_avg_return(environment, policy, num_episodes=10):

  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]


random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())

compute_avg_return(eval_env, random_policy, num_eval_episodes)

# Please also see the metrics module for standard implementations of different
# metrics.
28.7

Veri toplama

DQN eğitiminde olduğu gibi, yeniden oynatma arabelleğini ve rastgele ilkeyle ilk veri toplamayı ayarlayın.

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

def collect_step(environment, policy):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer
  replay_buffer.add_batch(traj)

for _ in range(initial_collect_steps):
  collect_step(train_env, random_policy)

# This loop is so common in RL, that we provide standard implementations of
# these. For more details see the drivers module.

# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=n_step_update + 1).prefetch(3)

iterator = iter(dataset)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/autograph/operators/control_flow.py:1218: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

Temsilcinin eğitimi

Eğitim döngüsü, hem ortamdan veri toplamayı hem de aracının ağlarını optimize etmeyi içerir. Yol boyunca, ne durumda olduğumuzu görmek için zaman zaman temsilcinin politikasını değerlendireceğiz.

Aşağıdakilerin çalışması ~ 7 dakika sürecektir.

try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few steps using collect_policy and save to the replay buffer.
  for _ in range(collect_steps_per_iteration):
    collect_step(train_env, agent.collect_policy)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience)

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss.loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1:.2f}'.format(step, avg_return))
    returns.append(avg_return)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
step = 200: loss = 3.321129560470581
step = 400: loss = 2.485752820968628
step = 600: loss = 2.0748603343963623
step = 800: loss = 1.899770736694336
step = 1000: loss = 1.9147026538848877
step = 1000: Average Return = 69.10
step = 1200: loss = 1.470450758934021
step = 1400: loss = 1.524451494216919
step = 1600: loss = 1.3602908849716187
step = 1800: loss = 1.3945512771606445
step = 2000: loss = 1.2128956317901611
step = 2000: Average Return = 201.20
step = 2200: loss = 1.2250053882598877
step = 2400: loss = 1.0739798545837402
step = 2600: loss = 1.0344221591949463
step = 2800: loss = 0.9437637329101562
step = 3000: loss = 1.0215129852294922
step = 3000: Average Return = 142.70
step = 3200: loss = 1.0233310461044312
step = 3400: loss = 0.8907231688499451
step = 3600: loss = 0.7526266574859619
step = 3800: loss = 0.6926383972167969
step = 4000: loss = 0.7934644222259521
step = 4000: Average Return = 476.80
step = 4200: loss = 0.791626513004303
step = 4400: loss = 0.8220507502555847
step = 4600: loss = 0.7975851893424988
step = 4800: loss = 0.4139212369918823
step = 5000: loss = 0.7318903207778931
step = 5000: Average Return = 310.40
step = 5200: loss = 0.7830334305763245
step = 5400: loss = 0.7445043921470642
step = 5600: loss = 0.6130998134613037
step = 5800: loss = 0.5654287338256836
step = 6000: loss = 0.6499170064926147
step = 6000: Average Return = 498.00
step = 6200: loss = 0.6856206655502319
step = 6400: loss = 0.613524317741394
step = 6600: loss = 0.5312545299530029
step = 6800: loss = 0.5998117923736572
step = 7000: loss = 0.35336682200431824
step = 7000: Average Return = 419.60
step = 7200: loss = 0.37572816014289856
step = 7400: loss = 0.3268156051635742
step = 7600: loss = 0.3964875340461731
step = 7800: loss = 0.4353790283203125
step = 8000: loss = 0.47257936000823975
step = 8000: Average Return = 209.20
step = 8200: loss = 0.41818156838417053
step = 8400: loss = 0.295656681060791
step = 8600: loss = 0.30348891019821167
step = 8800: loss = 0.2654055655002594
step = 9000: loss = 0.4846675992012024
step = 9000: Average Return = 431.30
step = 9200: loss = 0.281438410282135
step = 9400: loss = 0.23425081372261047
step = 9600: loss = 0.6559126377105713
step = 9800: loss = 0.4217219948768616
step = 10000: loss = 0.3250614404678345
step = 10000: Average Return = 283.80
step = 10200: loss = 0.2797137498855591
step = 10400: loss = 0.3637545108795166
step = 10600: loss = 0.2684471011161804
step = 10800: loss = 0.45216208696365356
step = 11000: loss = 0.26978206634521484
step = 11000: Average Return = 432.80
step = 11200: loss = 0.41701459884643555
step = 11400: loss = 0.39164310693740845
step = 11600: loss = 0.48381370306015015
step = 11800: loss = 0.3856581449508667
step = 12000: loss = 0.2671810984611511
step = 12000: Average Return = 412.60
step = 12200: loss = 0.37253132462501526
step = 12400: loss = 0.24322597682476044
step = 12600: loss = 0.48967045545578003
step = 12800: loss = 0.3843742907047272
step = 13000: loss = 0.3109121024608612
step = 13000: Average Return = 441.30
step = 13200: loss = 0.32548320293426514
step = 13400: loss = 0.3387058675289154
step = 13600: loss = 0.3758728504180908
step = 13800: loss = 0.2936052680015564
step = 14000: loss = 0.35974568128585815
step = 14000: Average Return = 427.80
step = 14200: loss = 0.3430924713611603
step = 14400: loss = 0.49261224269866943
step = 14600: loss = 0.39563947916030884
step = 14800: loss = 0.3216741681098938
step = 15000: loss = 0.3640541434288025
step = 15000: Average Return = 432.10

Görselleştirme

Arsalar

Temsilcimizin performansını görmek için geri dönüş ve küresel adımlar arasında bir grafik çizebiliriz. Cartpole-v1 ortam, direk yukarıda kaldığı her adım için +1 ödül verir ve maksimum adım sayısı 500 olduğundan, mümkün olan maksimum getiri de 500'dür.

steps = range(0, num_iterations + 1, eval_interval)
plt.plot(steps, returns)
plt.ylabel('Average Return')
plt.xlabel('Step')
plt.ylim(top=550)
(-15.555000400543214, 550.0)

png

VİDEOLAR

Her adımda ortamı oluşturarak bir aracının performansını görselleştirmek faydalıdır. Bunu yapmadan önce, videoları bu araştırmaya yerleştirmek için bir işlev oluşturalım.

def embed_mp4(filename):
  """Embeds an mp4 file in the notebook."""
  video = open(filename,'rb').read()
  b64 = base64.b64encode(video)
  tag = '''
  <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
  Your browser does not support the video tag.
  </video>'''.format(b64.decode())

  return IPython.display.HTML(tag)

Aşağıdaki kod, temsilcinin birkaç bölüm için politikasını görselleştirir:

num_episodes = 3
video_filename = 'imageio.mp4'
with imageio.get_writer(video_filename, fps=60) as video:
  for _ in range(num_episodes):
    time_step = eval_env.reset()
    video.append_data(eval_py_env.render())
    while not time_step.is_last():
      action_step = agent.policy.action(time_step)
      time_step = eval_env.step(action_step.action)
      video.append_data(eval_py_env.render())

embed_mp4(video_filename)
WARNING:root:IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (400, 600) to (400, 608) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to None (risking incompatibility). You may also see a FFMPEG warning concerning speedloss due to data not being aligned.

C51, CartPole-v1'de DQN'den biraz daha iyi olma eğilimindedir, ancak iki ajan arasındaki fark, gittikçe karmaşıklaşan ortamlarda gittikçe daha önemli hale gelmektedir. Örneğin, tam Atari 2600 kıyaslamasında, C51, rastgele bir ajana göre normalleştirildikten sonra DQN'ye göre% 126'lık bir ortalama skor artışı gösterir. N adımlı güncellemeler dahil edilerek ek iyileştirmeler elde edilebilir.

C51 algoritmasına daha derin bir bakış için, Güçlendirmeli Öğrenmeye Dağıtımsal Bakış Açısı (2017) bölümüne bakın.