Checkpointer و PolicySaver

عرض على TensorFlow.org تشغيل في Google Colab عرض المصدر على جيثب تحميل دفتر

مقدمة

tf_agents.utils.common.Checkpointer هو أداة لحفظ / تحميل دولة والتدريب، ودولة السياسة، والدولة replay_buffer إلى / من التخزين المحلي.

tf_agents.policies.policy_saver.PolicySaver هو أداة لحفظ / تحميل فقط السياسة، وأخف من Checkpointer . يمكنك استخدام PolicySaver لنشر نموذج، وكذلك دون أي معرفة من التعليمات البرمجية التي تم إنشاؤها السياسة.

في هذا البرنامج التعليمي، سوف نستخدم DQN لتدريب نموذج، ثم استخدم Checkpointer و PolicySaver لإظهار كيف يمكننا تخزين وتحميل الدول ونموذج بطريقة تفاعلية. علما بأن سوف نستخدم الأدوات saved_model جديدة TF2.0 وتنسيق PolicySaver .

يثبت

إذا لم تكن قد قمت بتثبيت التبعيات التالية ، فقم بتشغيل:

sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython

try:
  from google.colab import files
except ImportError:
  files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()

وكيل DQN

سنقوم بإعداد وكيل DQN ، تمامًا كما في colab السابق. يتم إخفاء التفاصيل افتراضيًا لأنها ليست جزءًا أساسيًا من هذا الكولاب ، ولكن يمكنك النقر فوق "إظهار الكود" للاطلاع على التفاصيل.

Hyperparameters

env_name = "CartPole-v1"

collect_steps_per_iteration = 100
replay_buffer_capacity = 100000

fc_layer_params = (100,)

batch_size = 64
learning_rate = 1e-3
log_interval = 5

num_eval_episodes = 10
eval_interval = 1000

بيئة

train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)

train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

وكيلات

جمع البيانات

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version.
Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.

تدريب الوكيل

توليد الفيديو

قم بإنشاء مقطع فيديو

تحقق من أداء السياسة عن طريق إنشاء مقطع فيديو.

print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step:
<tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>

gif

إعداد Checkpointer و PolicySaver

الآن نحن جاهزون لاستخدام Checkpointer و PolicySaver.

Checkpointer

checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step
)

حافظ السياسة

policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.

تدريب تكرار واحد

print('Training one iteration....')
train_one_iteration()
Training one iteration....
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version.
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))
iteration: 1 loss: 1.0214563608169556

حفظ في نقطة التفتيش

train_checkpointer.save(global_step)

استعادة نقطة التفتيش

لكي يعمل هذا ، يجب إعادة إنشاء مجموعة الكائنات بالكامل بنفس الطريقة التي تم بها إنشاء نقطة التحقق.

train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

أيضا حفظ السياسة والتصدير إلى الموقع

tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets

يمكن تحميل السياسة دون معرفة أي وكيل أو شبكة تم استخدامها لإنشائها. هذا يجعل نشر السياسة أسهل بكثير.

قم بتحميل السياسة المحفوظة وتحقق من كيفية أدائها

saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

التصدير والاستيراد

سيساعدك باقي colab على تصدير / استيراد دليل الفحص وأدلة السياسة بحيث يمكنك متابعة التدريب في وقت لاحق ونشر النموذج دون الحاجة إلى التدريب مرة أخرى.

يمكنك الآن العودة إلى "تدريب تكرار واحد" والتدريب بضع مرات أخرى حتى تتمكن من فهم الفرق لاحقًا. بمجرد أن تبدأ في رؤية نتائج أفضل قليلاً ، تابع أدناه.

إنشاء ملف مضغوط وتحميل ملف مضغوط (انقر نقرا مزدوجا لرؤية الرمز)

قم بإنشاء ملف مضغوط من دليل نقاط التفتيش.

train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))

قم بتنزيل الملف المضغوط.

if files is not None:
  files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

بعد التدريب لبعض الوقت (10-15 مرة) ، قم بتنزيل ملف checkpoint zip ، وانتقل إلى "Runtime> Restart and Run all" لإعادة تعيين التدريب ، والعودة إلى هذه الخلية. يمكنك الآن تحميل الملف المضغوط الذي تم تنزيله ومتابعة التدريب.

upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()

بمجرد تحميل دليل نقاط التفتيش ، ارجع إلى "تدريب واحد على التكرار" لمواصلة التدريب أو ارجع إلى "إنشاء فيديو" للتحقق من أداء السياسة المحملة.

بدلاً من ذلك ، يمكنك حفظ السياسة (النموذج) واستعادتها. على عكس checkpointer ، لا يمكنك متابعة التدريب ، ولكن لا يزال بإمكانك نشر النموذج. لاحظ أن الملف الذي تم تنزيله أصغر بكثير من ملف checkpointer.

tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel.
WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/policy/assets
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered.
  "imported and registered." % type_spec_class_name)
INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
  files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469

قم بتحميل دليل السياسة الذي تم تنزيله (exported_policy.zip) وتحقق من كيفية أداء السياسة المحفوظة.

upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)

gif

SavedModelPyTFEagerPolicy

إذا كنت لا تريد استخدام سياسة TF، ثم يمكنك أيضا استخدام saved_model مباشرة مع الحياة الفطرية بيثون من خلال استخدام py_tf_eager_policy.SavedModelPyTFEagerPolicy .

لاحظ أن هذا لا يعمل إلا عند تمكين الوضع الدائم.

eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())

# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)

gif

تحويل السياسة إلى TFLite

انظر TensorFlow لايت تحويل لمزيد من التفاصيل.

converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
  f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format.
2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency.
2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges.
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded

تشغيل الاستدلال على نموذج TFLite

انظر TensorFlow لايت الاستدلال لمزيد من التفاصيل.

import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))

policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
    '0/discount':tf.constant(0.0),
    '0/observation':tf.zeros([1,4]),
    '0/reward':tf.constant(0.0),
    '0/step_type':tf.constant(0)})
{'action': array([0])}