Checkpointer e PolicySaver

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

Introdução

tf_agents.utils.common.Checkpointer é um utilitário para salvar / carregar o estado de treinamento, estado da política e do estado replay_buffer de / para um dispositivo de armazenamento local.

tf_agents.policies.policy_saver.PolicySaver é uma ferramenta para salvar / carregar somente a política, e é mais leve do que Checkpointer . Você pode usar PolicySaver para implantar o modelo bem sem nenhum conhecimento do código que criou a política.

Neste tutorial, vamos usar DQN para treinar um modelo, em seguida, usar Checkpointer e PolicySaver para mostrar como podemos armazenar e carregar os estados e modelo de uma forma interactiva. Note que vamos utilizar novas ferramentas saved_model de TF2.0 e formato para PolicySaver .

Configurar

Se você não instalou as seguintes dependências, execute:

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython

try:
  from google.colab import files
except ImportError:
  files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()

Agente DQN

Vamos configurar o agente DQN, assim como na colab anterior. Os detalhes estão ocultos por padrão, pois não são parte central desta colab, mas você pode clicar em 'MOSTRAR CÓDIGO' para ver os detalhes.

Hiperparâmetros

env_name = "CartPole-v1"

collect_steps_per_iteration = 100
replay_buffer_capacity = 100000

fc_layer_params = (100,)

batch_size = 64
learning_rate = 1e-3
log_interval = 5

num_eval_episodes = 10
eval_interval = 1000

Ambiente

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

Agente

Coleção de dados

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

Treine o agente

Geração de Vídeo

Gerar um vídeo

Verifique o desempenho da política gerando um vídeo.

print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step:
<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>

gif

Configurar Checkpointer e PolicySaver

Agora estamos prontos para usar Checkpointer e PolicySaver.

Checkpointer

checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

Política de poupança

policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.

Treine uma iteração

print('Training one iteration....')
train_one_iteration()
Training one iteration....
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
iteration: 1 loss: 1.0214563608169556

Salvar no ponto de verificação

train_checkpointer.save(global_step)

Ponto de verificação de restauração

Para que isso funcione, todo o conjunto de objetos deve ser recriado da mesma forma que quando o ponto de verificação foi criado.

train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

Também salve a política e exporte para um local

tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets

A política pode ser carregada sem qualquer conhecimento de qual agente ou rede foi usado para criá-la. Isso torna a implantação da política muito mais fácil.

Carregue a política salva e verifique como ela funciona

saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

Exportar e importar

O resto do colab o ajudará a exportar / importar o checkpointer e diretórios de política de forma que você possa continuar o treinamento em um ponto posterior e implantar o modelo sem ter que treinar novamente.

Agora você pode voltar para 'Treinar uma iteração' e treinar mais algumas vezes para que possa entender a diferença mais tarde. Assim que começar a ver resultados ligeiramente melhores, continue abaixo.

Crie o arquivo zip e carregue o arquivo zip (clique duas vezes para ver o código)

Crie um arquivo compactado do diretório do ponto de verificação.

train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))

Baixe o arquivo zip.

if files is not None:
  files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

Após o treinamento por algum tempo (10-15 vezes), baixe o arquivo zip do ponto de verificação e vá para "Runtime> Reiniciar e executar tudo" para redefinir o treinamento e volte para esta célula. Agora você pode fazer upload do arquivo zip baixado e continuar o treinamento.

upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

Depois de carregar o diretório de pontos de verificação, volte para 'Treinar uma iteração' para continuar o treinamento ou volte para 'Gerar um vídeo' para verificar o desempenho da política carregada.

Como alternativa, você pode salvar a política (modelo) e restaurá-la. Ao contrário do checkpointer, você não pode continuar com o treinamento, mas ainda pode implantar o modelo. Observe que o arquivo baixado é muito menor do que o ponteiro de verificação.

tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
  files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

Faça upload do diretório de política baixado (export_policy.zip) e verifique o desempenho da política salva.

upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

SavedModelPyTFEagerPolicy

Se você não quiser usar a política TF, então você também pode usar o saved_model diretamente com o env Python através do uso de py_tf_eager_policy.SavedModelPyTFEagerPolicy .

Observe que isso só funciona quando o modo ansioso está ativado.

eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())

# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)

gif

Converter política para TFLite

Veja conversor TensorFlow Lite para mais detalhes.

converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
  f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

Executar inferência no modelo TFLite

Veja TensorFlow Lite Inference para mais detalhes.

import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))

policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
    '0/discount':tf.constant(0.0),
    '0/observation':tf.zeros([1,4]),
    '0/reward':tf.constant(0.0),
    '0/step_type':tf.constant(0)})
{'action': array([0])}