Estimator

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

警告: 新しいコードには Estimators は推奨されません。Estimators は v1.Session スタイルのコードを実行しますが、これは正しく記述するのはより難しく、特に TF 2 コードと組み合わせると予期しない動作をする可能性があります。Estimators は、互換性保証の対象となりますが、セキュリティの脆弱性以外の修正は行われません。詳細については、移行ガイドを参照してください。

このドキュメントでは、tf.estimator という高位 TensorFlow API を紹介します。Estimator は以下のアクションをカプセル化します。

  • トレーニング
  • 評価
  • 予測
  • 配信向けエクスポート

TensorFlow は、事前に作成された複数の Estimator を実装します。カスタムの Estimator は依然としてサポートされていますが、主に下位互換性の対策としてサポートされているため、新しいコードでは、カスタム Estimator を使用してはいけません。事前に作成された Estimator とカスタム Estimator はすべて、tf.estimator.Estimator クラスに基づくクラスです。

簡単な例については、Estimator チュートリアルを試してください。API デザインの概要については、ホワイトペーパーをご覧ください。

セットアップ

pip install -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds
2022-08-09 06:11:25.091238: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-09 06:11:25.776909: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-09 06:11:25.777149: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-09 06:11:25.777161: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

メリット

tf.keras.Model と同様に、estimator はモデルレベルの抽象です。tf.estimator は、tf.keras 向けに現在開発段階にある以下の機能を提供しています。

  • パラメーターサーバーベースのトレーニング
  • TFX の完全統合

Estimator の機能

Estimator には以下のメリットがあります。

  • Estimator ベースのモデルは、モデルを変更することなくローカルホストまたは分散マルチサーバー環境で実行できます。さらに、モデルをコーディングし直すことなく、CPU、GPU、または TPU で実行できます。
  • Estimator では、次を実行する方法とタイミングを制御する安全な分散型トレーニングループを使用できます。
    • データの読み込み
    • 例外の処理
    • チェックポイントファイルの作成と障害からの復旧
    • TensorBoard 用のサマリーの保存

Estimator を使ってアプリケーションを記述する場合、データ入力パイプラインとモデルを分離する必要があります。分離することで、異なるデータセットを伴う実験を単純化することができます。

事前作成済み Estimator を使用する

既成の Estimator を使うと、基本の TensorFlow API より非常に高い概念レベルで作業することができます。Estimator がすべての「配管作業」を処理してくれるため、計算グラフやセッションの作成などに気を回す必要がありません。さらに、事前作成済みの Estimator では、コード変更を最小限に抑えて多様なモデルアーキテクチャを使った実験を行えます。たとえば tf.estimator.DNNClassifier は、密度の高いフィードフォワードのニューラルネットワークに基づく分類モデルをトレーニングする事前作成済みの Estimator クラスです。

事前作成済み Estimator に依存する TensorFlow プログラムは、通常、次の 4 つのステップで構成されています。

1. 入力関数を作成する

たとえば、トレーニングセットをインポートする関数とテストセットをインポートする関数を作成する場合、Estimator は入力が次の 2 つのオブジェクトのペアとしてフォーマットされていることを期待します。

  • 特徴名のキーと対応する特徴データを含むテンソル(または SparseTensors)の値で構成されるディクショナリ
  • 1 つ以上のラベルを含むテンソル

input_fn は上記のフォーマットのペアを生成する tf.data.Dataset を返します。

たとえば、次のコードは Titanic データセットの train.csv ファイルから tf.data.Dataset を構築します。

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.AUTOTUNE))
  return titanic_batches

input_fn は、tf.Graph で実行し、グラフテンソルを含む (features_dics, labels) ペアを直接返すこともできますが、定数を返すといった単純なケースではない場合に、エラーが発生しやすくなります。

2. 特徴量カラムを定義する

tf.feature_column は、特徴量名、その型、およびすべての入力前処理を特定します。

たとえば、次のスニペットは 3 つの特徴量カラムを作成します。

  • 最初の特徴量カラムは、浮動小数点数の入力として直接 age 特徴量を使用します。
  • 2 つ目の特徴量カラムは、カテゴリカル入力として class 特徴量を使用します。
  • 3 つ目の特徴量カラムは、カテゴリカル入力として embark_town を使用しますが、オプションを列挙する必要がないように、またオプション数を設定するために、hashing trick を使用します。

詳細については、特徴量カラムのチュートリアルをご覧ください。

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. 関連する事前作成済み Estimator をインスタンス化する

LinearClassifier という事前作成済み Estimator のインスタンス化の例を次に示します。

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp9zht3zpu', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

詳細については、線形分類器のチュートリアルをご覧ください。

4. トレーニング、評価、または推論メソッドを呼び出す

すべての Estimator には、trainevaluate、および predict メソッドがあります。

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_v2/ftrl.py:153: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp9zht3zpu/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmpfs/tmp/tmp9zht3zpu/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.68833876.
2022-08-09 06:11:32.339278: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.checkpoint_management has been moved to tensorflow.python.checkpoint.checkpoint_management. The old module will be deleted in version 2.9.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-08-09T06:11:33
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9zht3zpu/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.71680s
INFO:tensorflow:Finished evaluation at 2022-08-09-06:11:33
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.634375, accuracy_baseline = 0.603125, auc = 0.72236955, auc_precision_recall = 0.6232533, average_loss = 0.6130203, global_step = 100, label/mean = 0.396875, loss = 0.6130203, precision = 0.5367647, prediction/mean = 0.45684123, recall = 0.5748032
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmpfs/tmp/tmp9zht3zpu/model.ckpt-100
accuracy : 0.634375
accuracy_baseline : 0.603125
auc : 0.72236955
auc_precision_recall : 0.6232533
average_loss : 0.6130203
label/mean : 0.396875
loss : 0.6130203
precision : 0.5367647
prediction/mean : 0.45684123
recall : 0.5748032
global_step : 100
2022-08-09 06:11:33.863142: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9zht3zpu/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [0.55210924]
logistic : [0.63462484]
probabilities : [0.3653752  0.63462484]
class_ids : [1]
classes : [b'1']
all_class_ids : [0 1]
all_classes : [b'0' b'1']
2022-08-09 06:11:34.696157: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

事前作成済み Estimator のメリット

事前作成済み Estimator は、次のようなベストプラクティスをエンコードするため、さまざまなメリットがあります。

  • さまざまな部分の計算グラフをどこで実行するかを決定し、単一のマシンまたはクラスタに戦略を実装するためのベストプラクティス。
  • イベント(要約)の書き込みと普遍的に役立つ要約のベストプラクティス。

事前作成済み Estimator を使用しない場合は、上記の特徴量を独自に実装する必要があります。

カスタム Estimator

事前作成済みかカスタムかに関係なく、すべての Estimator の中核は、モデル関数model_fn にあります。これは、トレーニング、評価、および予測に使用するグラフを構築するメソッドです。事前作成済み Estimator を使用する場合は、モデル関数はすでに実装されていますが、カスタム Estimator を使用する場合は、モデル関数を自分で記述する必要があります。

注意: カスタム model_fn は 1.x スタイルのグラフモードでそのまま実行します。つまり、Eager execution はなく、依存関係の自動制御もないため、tf.estimator からカスタム model_fn に移行する必要があります。代替の API は tf.kerastf.distribute です。トレーニングの一部に Estimator を使用する必要がある場合は、tf.keras.estimator.model_to_estimator コンバータを使用して keras.Model から Estimator を作成する必要があります。

Keras モデルから Estimator を作成する

tf.keras.estimator.model_to_estimator を使用して、既存の Keras モデルを Estimator に変換できます。モデルコードを最新の状態に変更したくても、トレーニングパイプラインに Estimator が必要な場合に役立ちます。

Keras MobileNet V2 モデルをインスタンス化し、トレーニングに使用する optimizer、loss、および metrics とともにモデルをコンパイルします。

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9406464/9406464 [==============================] - 0s 0us/step

コンパイルされた Keras モデルから Estimator を作成します。Keras モデルの初期化状態が、作成した Estimator に維持されます。

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp664xathg
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/backend.py:450: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp664xathg', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

派生した Estimator をほかの Estimator と同じように扱います。

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

トレーニングするには、Estimator の train 関数を呼び出します。

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.util has been moved to tensorflow.python.checkpoint.checkpoint. The old module will be deleted in version 2.11.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.util has been moved to tensorflow.python.checkpoint.checkpoint. The old module will be deleted in version 2.11.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmpfs/tmp/tmp664xathg/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmpfs/tmp/tmp664xathg/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmpfs/tmp/tmp664xathg/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmpfs/tmp/tmp664xathg/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp664xathg/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp664xathg/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6884295, step = 0
INFO:tensorflow:loss = 0.6884295, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmpfs/tmp/tmp664xathg/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmpfs/tmp/tmp664xathg/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.6258749.
INFO:tensorflow:Loss for final step: 0.6258749.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fa0bdd900d0>

同様に、評価するには、Estimator の evaluate 関数を呼び出します。

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/engine/training_v1.py:2045: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  updates = self.state_updates
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-08-09T06:11:52
INFO:tensorflow:Starting evaluation at 2022-08-09T06:11:52
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp664xathg/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp664xathg/model.ckpt-50
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 2.76210s
INFO:tensorflow:Inference Time : 2.76210s
INFO:tensorflow:Finished evaluation at 2022-08-09-06:11:55
INFO:tensorflow:Finished evaluation at 2022-08-09-06:11:55
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.58125, global_step = 50, loss = 0.6781251
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.58125, global_step = 50, loss = 0.6781251
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmpfs/tmp/tmp664xathg/model.ckpt-50
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmpfs/tmp/tmp664xathg/model.ckpt-50
{'accuracy': 0.58125, 'loss': 0.6781251, 'global_step': 50}

詳細については、tf.keras.estimator.model_to_estimator のドキュメントを参照してください。

Estimator でオブジェクトベースのチェックポイントを保存する

Estimator はデフォルトで、チェックポイントガイドで説明したオブジェクトグラフではなく、変数名でチェックポイントを保存します。tf.train.Checkpoint は名前ベースのチェックポイントを読み取りますが、モデルの一部を Estimator の model_fn の外側に移動すると変数名が変わることがあります。上位互換性においては、オブジェクトベースのチェックポイントを保存すると、Estimator の内側でモデルをトレーニングし、外側でそれを使用することが容易になります。

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.536174, step = 0
INFO:tensorflow:loss = 4.536174, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 37.49402.
INFO:tensorflow:Loss for final step: 37.49402.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fa0445b3220>

その後、tf.train.Checkpoint は Estimator のチェックポイントをその model_dir から読み込むことができます。

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

Estimator の SavedModel

Estimator は、tf.Estimator.export_saved_model によって SavedModel をエクスポートします。

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpmfyc6apc
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpmfyc6apc
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpmfyc6apc', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpmfyc6apc', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpmfyc6apc/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpmfyc6apc/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmpfs/tmp/tmpmfyc6apc/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmpfs/tmp/tmpmfyc6apc/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.44521567.
INFO:tensorflow:Loss for final step: 0.44521567.
<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7fa0441279a0>

Estimator を保存するには、serving_input_receiver を作成する必要があります。この関数は、SavedModel が受け取る生データを解析する tf.Graph の一部を構築します。

tf.estimator.export モジュールには、これらの receivers を構築するための関数が含まれています。

次のコードは、feature_columns に基づき、tf-serving と合わせて使用されることの多いシリアル化された tf.Example プロトコルバッファを受け入れるレシーバーを構築します。

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:146: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:146: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/saved_model/model_utils/export_utils.py:84: get_tensor_from_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or tf.compat.v1.saved_model.get_tensor_from_tensor_info.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/saved_model/model_utils/export_utils.py:84: get_tensor_from_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or tf.compat.v1.saved_model.get_tensor_from_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpmfyc6apc/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpmfyc6apc/model.ckpt-50
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /tmpfs/tmp/tmp6g68lnpf/from_estimator/temp-1660025517/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmpfs/tmp/tmp6g68lnpf/from_estimator/temp-1660025517/saved_model.pb

また、Python からモデルを読み込んで実行することも可能です。

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.31333292]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.5776986]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.4223014, 0.5776986]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>}
{'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.2550488]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2218274]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7781726, 0.2218274]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}

tf.estimator.export.build_raw_serving_input_receiver_fn を使用すると、tf.train.Example の代わりに生のテンソルを取る入力関数を作成することができます。

Estimator を使った tf.distribute.Strategy の使用(制限サポート)

tf.estimator は、もともと非同期パラメーターサーバー手法をサポートしていた分散型トレーニング TensorFlow API です。tf.estimator は現在では tf.distribute.Strategy をサポートするようになっています。tf.estimator を使用している場合は、コードを少し変更するだけで、分散型トレーニングに変更することができます。これにより、Estimator ユーザーは複数の GPU と複数のワーカーだけでなく、TPU でも同期分散型トレーニングを実行できるようになりましたが、Estimator でのこのサポートには制限があります。詳細については、以下に示す「現在、何がサポートされていますか」セクションをご覧ください。

Estimator での tf.distribute.Strategy の使用は、Keras の事例とわずかに異なります。strategy.scope を使用する代わりに、ストラテジーオブジェクトを Estimator の RunConfig に渡します。

詳細については、分散型トレーニングガイドをご覧ください。

次は、事前に作成された Estimator LinearRegressorMirroredStrategy を使ってこの動作を示すコードスニペットです。

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp_1412dfu
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp_1412dfu
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp_1412dfu', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fa0bdbe46d0>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fa0bdbe46d0>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp_1412dfu', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fa0bdbe46d0>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7fa0bdbe46d0>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

ここでは、事前に作成された Estimator が使用されていますが、同じコードはカスタム Estimator でも動作します。train_distribute はトレーニングの分散方法を判定し、eval_distribute は評価の分散方法を判定します。この点も、トレーニングと評価に同じストラテジーを使用する Keras と異なるところです。

入力関数を使用して、この Estimator をトレーニングし、評価することができます。

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1244: StrategyBase.configure (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
use `update_config_proto` instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1244: StrategyBase.configure (from tensorflow.python.distribute.distribute_lib) is deprecated and will be removed in a future version.
Instructions for updating:
use `update_config_proto` instead.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.graph_view has been moved to tensorflow.python.checkpoint.graph_view. The old module will be deleted in version 2.11.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:461: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.graph_view has been moved to tensorflow.python.checkpoint.graph_view. The old module will be deleted in version 2.11.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 2 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.v1.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.v1.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp_1412dfu/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp_1412dfu/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
2022-08-09 06:12:02.525862: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2022-08-09 06:12:02.527102: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

2022-08-09 06:12:02.545138: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2022-08-09 06:12:02.545612: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

2022-08-09 06:12:02.552618: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2022-08-09 06:12:02.553082: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

2022-08-09 06:12:02.563480: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2022-08-09 06:12:02.563974: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'
INFO:tensorflow:loss = 4.0, step = 0
INFO:tensorflow:loss = 4.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmpfs/tmp/tmp_1412dfu/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into /tmpfs/tmp/tmp_1412dfu/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 1.1510792e-12.
INFO:tensorflow:Loss for final step: 1.1510792e-12.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-08-09T06:12:04
INFO:tensorflow:Starting evaluation at 2022-08-09T06:12:04
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp_1412dfu/model.ckpt-10
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp_1412dfu/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
2022-08-09 06:12:05.343276: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2022-08-09 06:12:05.344484: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

2022-08-09 06:12:05.353191: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2022-08-09 06:12:05.353672: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

2022-08-09 06:12:05.362855: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2022-08-09 06:12:05.363336: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'

2022-08-09 06:12:05.373979: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2022-08-09 06:12:05.374499: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.67275s
INFO:tensorflow:Inference Time : 0.67275s
INFO:tensorflow:Finished evaluation at 2022-08-09-06:12:05
INFO:tensorflow:Finished evaluation at 2022-08-09-06:12:05
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmpfs/tmp/tmp_1412dfu/model.ckpt-10
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmpfs/tmp/tmp_1412dfu/model.ckpt-10
{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

Estimator と Keras のもう 1 つの違いとして強調すべき点は入力の処理方法です。Keras では、データセットの各バッチは複数のレプリカに自動的に分断されますが、Estimator の場合は、バッチの自動分断やワーカーをまたいで自動的にシャーディングすることもありません。ワーカーやデバイスでのデータの分散方法はユーザーが完全に制御するものであるため、input_fn を提供してデータの分散方法を指定する必要があります。

input_fn はワーカー当たり一度呼び出されるため、ワーカー当たり 1 つのデータセットが与えられます。次に、そのデータセットの 1 つのバッチがそのワーカーの 1 つのレプリカに供給され、したがって、1 つのワーカーの N 個のレプリカに対して N 個のバッチが消費されることになります。言い換えると、input_fn が返すデータセットは、サイズ PER_REPLICA_BATCH_SIZE のバッチを提供するということです。ステップのグローバルバッチサイズは、PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync として取得することができます。

マルチワーカートレーニングを行う場合は、データをワーカー間で分割するか、それぞれにランダムシードを使用してシャッフルする必要があります。これを行う方法の例は、「Estimator を使ったマルチワーカートレーニング」を参照してください。

また、同様に、マルチワーカーとパラメーターサーバーストラテジーを使用することができます。コードは変わりませんが、tf.estimator.train_and_evaluate を使用し、クラスタで実行している各バイナリの TF_CONFIG 環境変数を設定する必要があります。

現在、何がサポートされていますか?

TPUStrategy を除くすべてのストラテジーを使った Estimator でのトレーニングのサポートには制限があります。基本的なトレーニングと評価は機能しますが、v1.train.Scaffold などの多数の高度な機能はまだ機能しません。また、この統合には多数のバグも存在する可能性があります。現時点では、Keras とカスタムトレーニングループのサポートに注力しているため、このサポートを積極的に改善する予定はありません。可能な限り、それらの API で tf.distribute を使用するようにしてください。

トレーニング API MirroredStrategy TPUStrategy MultiWorkerMirroredStrategy CentralStorageStrategy ParameterServerStrategy
Estimator API 制限サポート 未サポート 制限サポート 制限サポート 制限サポート

例とチュートリアル

次は、Estimator によるさまざまなストラテジーの使用方法を示す、エンドツーエンドの例です。

  1. Estimator を使ったマルチワーカートレーニングのチュートリアルには、MNIST データセットで MultiWorkerMirroredStrategy を使って複数のワーカーをトレーニングする方法が説明されています。
  2. Kubernetes テンプレートを使った tensorflow/ecosystem分散ストラテジーによってマルチワーカートレーニングを実行するエンドツーエンドの例。Keras モデルから始め、tf.keras.estimator.model_to_estimator API を使って Estimator に変換します。
  3. ResNet50 の公式モデル。MirroredStrategy または MultiWorkerMirroredStrategy を使ってトレーニングできます。