Контрольный указатель и PolicySaver

Посмотреть на TensorFlow.org Запускаем в Google Colab Посмотреть исходный код на GitHub Скачать блокнот

Введение

tf_agents.utils.common.Checkpointer это утилита для сохранения / загрузки учебного состояния, государственной политики, а также состояние replay_buffer в / из локального хранилища.

tf_agents.policies.policy_saver.PolicySaver является инструментом для сохранения / загрузки только политику, и легче , чем Checkpointer . Вы можете использовать PolicySaver для развертывания модели, а без знания кода , который создал политику.

В этом уроке мы будем использовать DQN для обучения модели, а затем использовать Checkpointer и PolicySaver , чтобы показать , как мы можем сохранять и загружать состояния и модели в интерактивном режиме. Обратите внимание , что мы будем использовать новый saved_model инструмент и формат TF2.0 в для PolicySaver .

Настраивать

Если вы не установили следующие зависимости, запустите:

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython

try:
  from google.colab import files
except ImportError:
  files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()

Агент DQN

Мы собираемся настроить агент DQN, как и в предыдущем колабе. Детали по умолчанию скрыты, так как они не являются основной частью этого колаба, но вы можете нажать «ПОКАЗАТЬ КОД», чтобы увидеть детали.

Гиперпараметры

env_name = "CartPole-v1"

collect_steps_per_iteration = 100
replay_buffer_capacity = 100000

fc_layer_params = (100,)

batch_size = 64
learning_rate = 1e-3
log_interval = 5

num_eval_episodes = 10
eval_interval = 1000

Окружающая обстановка

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Агент

Сбор данных

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

Обучить агента

Генерация видео

Создать видео

Проверьте работоспособность политики, сгенерировав видео.

print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step:
<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>

гифка

Настройка Checkpointer и PolicySaver

Теперь мы готовы использовать Checkpointer и PolicySaver.

Контрольный указатель

checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

Политика экономии

policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.

Обучить одну итерацию

print('Training one iteration....')
train_one_iteration()
Training one iteration....
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
iteration: 1 loss: 1.0214563608169556

Сохранить на контрольную точку

train_checkpointer.save(global_step)

Восстановить контрольную точку

Чтобы это сработало, весь набор объектов должен быть воссоздан так же, как при создании контрольной точки.

train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

Также сохраните политику и экспортируйте в место

tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets

Политику можно загрузить, не зная, какой агент или сеть использовались для ее создания. Это значительно упрощает развертывание политики.

Загрузите сохраненную политику и проверьте, как она работает

saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

гифка

Экспорт и импорт

Остальная часть колаба поможет вам экспортировать / импортировать каталоги контрольных точек и политик, чтобы вы могли продолжить обучение позже и развернуть модель без необходимости повторного обучения.

Теперь вы можете вернуться к «Обучить одну итерацию» и потренироваться еще несколько раз, чтобы позже вы могли понять разницу. Как только вы начнете видеть немного лучшие результаты, продолжайте ниже.

Создайте zip-файл и загрузите zip-файл (дважды щелкните, чтобы увидеть код)

Создайте заархивированный файл из каталога контрольной точки.

train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))

Загрузите zip-файл.

if files is not None:
  files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

После некоторого обучения (10-15 раз) загрузите zip-файл контрольной точки и перейдите в «Runtime> Restart and run all», чтобы сбросить обучение и вернуться в эту ячейку. Теперь вы можете загрузить скачанный zip-файл и продолжить обучение.

upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

После того, как вы загрузили каталог контрольных точек, вернитесь к «Обучить одну итерацию», чтобы продолжить обучение, или вернитесь к «Сгенерировать видео», чтобы проверить производительность загруженной политики.

Как вариант, вы можете сохранить политику (модель) и восстановить ее. В отличие от контрольной точки, вы не можете продолжить обучение, но вы все равно можете развернуть модель. Обратите внимание, что загруженный файл намного меньше, чем у контрольной точки.

tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
  files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

Загрузите каталог загруженной политики (exported_policy.zip) и проверьте, как работает сохраненная политика.

upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

гифка

SavedModelPyTFEagerPolicy

Если вы не хотите использовать политику TF, то вы можете также использовать saved_model непосредственно с окр Python посредством использования py_tf_eager_policy.SavedModelPyTFEagerPolicy .

Обратите внимание, что это работает, только когда включен активный режим.

eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())

# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)

гифка

Преобразовать политику в TFLite

Смотрите конвертер TensorFlow Lite для получения более подробной информации.

converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
  f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

Выполнить вывод на модели TFLite

См TensorFlow Lite Умозаключение для более подробной информации.

import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))

policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
    '0/discount':tf.constant(0.0),
    '0/observation':tf.zeros([1,4]),
    '0/reward':tf.constant(0.0),
    '0/step_type':tf.constant(0)})
{'action': array([0])}