Tarihi kaydet! Google I / O 18-20 Mayıs'ta geri dönüyor Şimdi kaydolun
Bu sayfa, Cloud Translation API ile çevrilmiştir.
Switch to English

Checkpointer ve PolicySaver

TensorFlow.org'da görüntüleyin Google Colab'de çalıştırın Kaynağı GitHub'da görüntüleyin Defteri indirin

Giriş

tf_agents.utils.common.Checkpointer , eğitim durumunu, ilke durumunu ve replay_buffer durumunu yerel bir depolamaya / bellekten kaydetmek / yüklemek için bir yardımcı programdır.

tf_agents.policies.policy_saver.PolicySaver , yalnızca politikayı kaydetmek / yüklemek için bir araçtır ve Checkpointer daha hafiftir. PolicySaver oluşturan kod hakkında herhangi bir bilgi sahibi olmadan da modeli dağıtmak için PolicySaver kullanabilirsiniz.

Bu öğreticide, bir modeli eğitmek için DQN kullanacağız, ardından durumları ve modeli etkileşimli bir şekilde nasıl saklayıp yükleyebileceğimizi göstermek için Checkpointer ve PolicySaver kullanacağız. TF2.0'ın yeni Saved_model araçlarını ve PolicySaver biçimini PolicySaver .

Kurulum

Aşağıdaki bağımlılıkları yüklemediyseniz, çalıştırın:

sudo apt-get install -y xvfb ffmpeg
pip install -q 'imageio==2.4.0'
pip install -q 'xvfbwrapper==0.2.9'
pip install -q tf-agents
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython

try:
  from google.colab import files
except ImportError:
  files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

tf.compat.v1.enable_v2_behavior()

tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()

DQN aracısı

Bir önceki colab'da olduğu gibi DQN ajanı kuracağız. Ayrıntılar, bu sütunun temel parçası olmadıklarından varsayılan olarak gizlidir, ancak ayrıntıları görmek için "KODU GÖSTER" e tıklayabilirsiniz.

Hiperparametreler

env_name = "CartPole-v1"

collect_steps_per_iteration = 100
replay_buffer_capacity = 100000

fc_layer_params = (100,)

batch_size = 64
learning_rate = 1e-3
log_interval = 5

num_eval_episodes = 10
eval_interval = 1000

Çevre

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Ajan

q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

global_step = tf.compat.v1.train.get_or_create_global_step()

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=global_step)
agent.initialize()

Veri toplama

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

collect_driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=collect_steps_per_iteration)

# Initial data collection
collect_driver.run()

# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=2).prefetch(3)

iterator = iter(dataset)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tf_agents/drivers/dynamic_step_driver.py:203: calling while_loop_v2 (from tensorflow.python.ops.control_flow_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/autograph/operators/control_flow.py:1218: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

Temsilciyi eğitin

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

def train_one_iteration():

  # Collect a few steps using collect_policy and save to the replay buffer.
  collect_driver.run()

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience)

  iteration = agent.train_step_counter.numpy()
  print ('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))

Video Üretimi

def embed_gif(gif_buffer):
  """Embeds a gif file in the notebook."""
  tag = '<img src="data:image/gif;base64,{0}"/>'.format(base64.b64encode(gif_buffer).decode())
  return IPython.display.HTML(tag)

def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):
  num_episodes = 3
  frames = []
  for _ in range(num_episodes):
    time_step = eval_tf_env.reset()
    frames.append(eval_py_env.render())
    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = eval_tf_env.step(action_step.action)
      frames.append(eval_py_env.render())
  gif_file = io.BytesIO()
  imageio.mimsave(gif_file, frames, format='gif', fps=60)
  IPython.display.display(embed_gif(gif_file.getvalue()))

Bir video oluşturun

Bir video oluşturarak politikanın performansını kontrol edin.

print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step:
<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>

gif

Checkpointer ve PolicySaver Kurulumu

Artık Checkpointer ve PolicySaver'ı kullanmaya hazırız.

Kontrol noktası

checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

Politika Tasarrufu

policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)

Bir yineleme eğitin

print('Training one iteration....')
train_one_iteration()
Training one iteration....
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py:201: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
iteration: 1 loss: 0.9551953077316284

Kontrol noktasına kaydet

train_checkpointer.save(global_step)

Denetim noktasını geri yükle

Bunun çalışması için, tüm nesne kümesi, kontrol noktası oluşturulduğu zamanki gibi yeniden oluşturulmalıdır.

train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

Ayrıca politikayı kaydedin ve bir konuma aktarın

tf_policy_saver.save(policy_dir)
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
INFO:tensorflow:Assets written to: /tmp/policy/assets

Politika, onu oluşturmak için hangi aracının veya ağın kullanıldığına dair herhangi bir bilgi sahibi olmadan yüklenebilir. Bu, politikanın uygulanmasını çok daha kolay hale getirir.

Kaydedilen politikayı yükleyin ve nasıl çalıştığını kontrol edin

saved_policy = tf.compat.v2.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

Ihracat ve ithalat

Kolabın geri kalanı, kontrol noktası ve ilke dizinlerini dışa / içe aktarmanıza yardımcı olur, böylece eğitime daha sonraki bir noktada devam edebilir ve modeli yeniden eğitmek zorunda kalmadan dağıtabilirsiniz.

Şimdi 'Bir yinelemeyi eğit'e geri dönebilir ve farkı daha sonra anlayabilmek için birkaç kez daha eğitebilirsiniz. Biraz daha iyi sonuçlar görmeye başladıktan sonra aşağıdan devam edin.

Zip dosyası oluşturun ve zip dosyası yükleyin (kodu görmek için çift tıklayın)

def create_zip_file(dirname, base_filename):
  return shutil.make_archive(base_filename, 'zip', dirname)

def upload_and_unzip_file_to(dirname):
  if files is None:
    return
  uploaded = files.upload()
  for fn in uploaded.keys():
    print('User uploaded file "{name}" with length {length} bytes'.format(
        name=fn, length=len(uploaded[fn])))
    shutil.rmtree(dirname)
    zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')
    zip_files.extractall(dirname)
    zip_files.close()

Kontrol noktası dizininden sıkıştırılmış bir dosya oluşturun.

train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))

Zip dosyasını indirin.

if files is not None:
  files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

Bir süre eğitimden sonra (10-15 kez), kontrol noktası zip dosyasını indirin ve eğitimi sıfırlamak için "Çalışma Zamanı> Yeniden Başlat ve Tümünü Çalıştır" seçeneğine gidin ve bu hücreye geri dönün. Artık indirilen zip dosyasını yükleyebilir ve eğitime devam edebilirsiniz.

upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

Kontrol noktası dizinini yükledikten sonra, eğitime devam etmek için 'Bir yineleme eğit'e geri dönün veya yüklenen poliicy'nin performansını kontrol etmek için' Bir video oluştur 'bölümüne geri dönün.

Alternatif olarak, politikayı (model) kaydedebilir ve geri yükleyebilirsiniz. Kontrol noktasının aksine, eğitime devam edemezsiniz, ancak modeli yine de dağıtabilirsiniz. İndirilen dosyanın kontrol işaretininkinden çok daha küçük olduğuna dikkat edin.

tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
  files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

İndirilen politika dizinini (exported_policy.zip) yükleyin ve kaydedilen politikanın nasıl performans gösterdiğini kontrol edin.

upload_and_unzip_file_to(policy_dir)
saved_policy = tf.compat.v2.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

SavedModelPyTFEagerPolicy

Eğer TF ilkesini kullanmak istemiyorsanız, o zaman da kullanılması yoluyla Python env doğrudan saved_model kullanabilirsiniz py_tf_eager_policy.SavedModelPyTFEagerPolicy .

Bunun yalnızca istekli mod etkinleştirildiğinde çalıştığını unutmayın.

eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())

# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)

gif