![]() | ![]() | ![]() | ![]() |
このドキュメントの紹介tf.estimator
-a高レベルTensorFlowのAPI。 Estimatorは、次のアクションをカプセル化します。
- トレーニング
- 評価
- 予測
- 提供するためのエクスポート
TensorFlowは、いくつかの既成のEstimatorを実装しています。カスタム推定量は引き続きサポートされていますが、主に下位互換性の指標として使用されます。カスタム推定量は、新しいコードには使用しないでください。すべてのEstimator(既製またはカスタムのもの)は、 tf.estimator.Estimator
クラスに基づくクラスです。
簡単な例として、 Estimatorチュートリアルを試してください。 API設計の概要については、ホワイトペーパーを確認してください。
セットアップ
pip install -q -U tensorflow_datasets
import tempfile
import os
import tensorflow as tf
import tensorflow_datasets as tfds
利点
tf.keras.Model
と同様に、 estimator
はモデルレベルの抽象化です。 tf.estimator
は、 tf.keras
用に現在開発中のいくつかの機能を提供します。これらは:
- パラメータサーバーベースのトレーニング
- 完全なTFX統合
推定量の機能
Estimatorには次の利点があります。
- モデルを変更せずに、ローカルホストまたは分散マルチサーバー環境でEstimatorベースのモデルを実行できます。さらに、モデルを再コーディングせずに、CPU、GPU、またはTPUでEstimatorベースのモデルを実行できます。
- Estimatorは、次の方法とタイミングを制御する安全な分散トレーニングループを提供します。
- データを読み込む
- 例外を処理する
- チェックポイントファイルを作成し、障害から回復します
- TensorBoardの概要を保存する
Estimatorを使用してアプリケーションを作成する場合は、データ入力パイプラインをモデルから分離する必要があります。この分離により、さまざまなデータセットでの実験が簡素化されます。
既製の推定量を使用する
事前に作成されたEstimatorを使用すると、基本のTensorFlowAPIよりもはるかに高い概念レベルで作業できます。 Estimatorがすべての「配管」を処理するため、計算グラフやセッションの作成について心配する必要はありません。さらに、事前に作成されたEstimatorを使用すると、最小限のコード変更のみを行うことで、さまざまなモデルアーキテクチャを試すことができます。たとえば、 tf.estimator.DNNClassifier
は、高密度のフィードフォワードニューラルネットワークに基づいて分類モデルをトレーニングする、事前に作成されたEstimatorクラスです。
事前に作成されたEstimatorに依存するTensorFlowプログラムは、通常、次の4つのステップで構成されます。
1.入力関数を書く
たとえば、トレーニングセットをインポートする関数と、テストセットをインポートする関数を作成できます。推定者は、入力がオブジェクトのペアとしてフォーマットされることを期待しています。
- キーが機能名であり、値が対応する機能データを含むテンソル(またはSparseTensors)である辞書
- 1つ以上のラベルを含むテンソル
input_fn
は、その形式でペアを生成するtf.data.Dataset
を返す必要があります。
たとえば、次のコードは、Titanicデータセットのtrain.csv
ファイルからtf.data.Dataset
を構築します。
def train_input_fn():
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic = tf.data.experimental.make_csv_dataset(
titanic_file, batch_size=32,
label_name="survived")
titanic_batches = (
titanic.cache().repeat().shuffle(500)
.prefetch(tf.data.AUTOTUNE))
return titanic_batches
input_fn
実行されるtf.Graph
と直接返すことができる(features_dics, labels)
グラフテンソルを含む対を、これは復帰定数のような単純なケースのエラープローン外です。
2.フィーチャ列を定義します。
各tf.feature_column
は、機能名、そのタイプ、および入力の前処理を識別します。
たとえば、次のスニペットは3つのフィーチャ列を作成します。
- 1つ目は、
age
機能を浮動小数点入力として直接使用します。 - 2つ目は、
class
機能をカテゴリ入力として使用します。 - 3つ目は、
embark_town
をカテゴリ入力として使用しますが、hashing trick
を使用して、オプションを列挙する必要をembark_town
、オプションの数を設定します。
詳細については、機能列のチュートリアルを確認してください。
age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third'])
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)
3.関連する既成のEstimatorをインスタンス化します。
たとえば、 LinearClassifier
という名前の事前に作成されたEstimatorのサンプルインスタンス化をLinearClassifier
ます。
model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
model_dir=model_dir,
feature_columns=[embark, cls, age],
n_classes=2
)
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpu27sw9ie', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
詳細については、線形分類器のチュートリアルを参照してください。
4.トレーニング、評価、または推論の方法を呼び出します。
すべてのEstimatorは、 train
、 evaluate
、およびpredict
方法を提供します。
model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv 32768/30874 [===============================] - 0s 0us/step INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1727: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead. warnings.warn('`layer.add_variable` is deprecated and ' WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:134: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpu27sw9ie/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100... INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpu27sw9ie/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100... INFO:tensorflow:Loss for final step: 0.62258995.
result = model.evaluate(train_input_fn, steps=10)
for key, value in result.items():
print(key, ":", value)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-01-08T02:56:30Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpu27sw9ie/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.67613s INFO:tensorflow:Finished evaluation at 2021-01-08-02:56:31 INFO:tensorflow:Saving dict for global step 100: accuracy = 0.715625, accuracy_baseline = 0.60625, auc = 0.7403657, auc_precision_recall = 0.6804854, average_loss = 0.5836128, global_step = 100, label/mean = 0.39375, loss = 0.5836128, precision = 0.739726, prediction/mean = 0.34897345, recall = 0.42857143 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpu27sw9ie/model.ckpt-100 accuracy : 0.715625 accuracy_baseline : 0.60625 auc : 0.7403657 auc_precision_recall : 0.6804854 average_loss : 0.5836128 label/mean : 0.39375 loss : 0.5836128 precision : 0.739726 prediction/mean : 0.34897345 recall : 0.42857143 global_step : 100
for pred in model.predict(train_input_fn):
for key, value in pred.items():
print(key, ":", value)
break
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpu27sw9ie/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. logits : [-0.73942876] logistic : [0.32312906] probabilities : [0.6768709 0.3231291] class_ids : [0] classes : [b'0'] all_class_ids : [0 1] all_classes : [b'0' b'1']
事前に作成されたEstimatorの利点
事前に作成されたEstimatorはベストプラクティスをエンコードし、次の利点を提供します。
- 計算グラフのさまざまな部分を実行する場所を決定し、単一のマシンまたはクラスターに戦略を実装するためのベストプラクティス。
- イベント(要約)の作成と普遍的に役立つ要約のベストプラクティス。
事前に作成されたEstimatorを使用しない場合は、前述の機能を自分で実装する必要があります。
カスタム推定量
すべてのEstimatorの中心は、事前に作成されているかカスタムであるかにmodel_fn
、そのモデル関数model_fn
。これは、トレーニング、評価、および予測のためのグラフを作成するメソッドです。事前に作成されたEstimatorを使用している場合、他の誰かがすでにモデル関数を実装しています。カスタムEstimatorに依存する場合は、モデル関数を自分で作成する必要があります。
Kerasモデルから推定量を作成する
tf.keras.estimator.model_to_estimator
を使用して、既存のtf.keras.estimator.model_to_estimator
モデルをEstimatorに変換できます。これは、モデルコードを最新化したいが、トレーニングパイプラインにEstimatorが必要な場合に役立ちます。
Keras MobileNet V2モデルをインスタンス化し、オプティマイザー、損失、およびメトリックを使用してモデルをコンパイルし、トレーニングを行います。
keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False
estimator_model = tf.keras.Sequential([
keras_mobilenet_v2,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(1)
])
# Compile the model
estimator_model.compile(
optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5 9412608/9406464 [==============================] - 0s 0us/step
コンパイルされたKerasモデルからEstimator
を作成します。 Kerasモデルの初期モデル状態は、作成されたEstimator
保持されます。
est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpeaonpwe8 INFO:tensorflow:Using the Keras model provided. /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/backend.py:434: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model. warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and ' INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpeaonpwe8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
導出されたEstimator
は、他のEstimator
と同じように扱います。
IMG_SIZE = 160 # All images will be resized to 160x160
def preprocess(image, label):
image = tf.cast(image, tf.float32)
image = (image/127.5) - 1
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image, label
def train_input_fn(batch_size):
data = tfds.load('cats_vs_dogs', as_supervised=True)
train_data = data['train']
train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
return train_data
トレーニングするには、Estimatorのトレイン関数を呼び出します。
est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
Downloading and preparing dataset 786.68 MiB (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0... WARNING:absl:1738 images were corrupted and were skipped Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpeaonpwe8/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpeaonpwe8/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting from: /tmp/tmpeaonpwe8/keras/keras_model.ckpt INFO:tensorflow:Warm-starting from: /tmp/tmpeaonpwe8/keras/keras_model.ckpt INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6884984, step = 0 INFO:tensorflow:loss = 0.6884984, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.67705643. INFO:tensorflow:Loss for final step: 0.67705643. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f3d7c3822b0>
同様に、評価するには、Estimatorの評価関数を呼び出します。
est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2325: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically. warnings.warn('`Model.state_updates` will be removed in a future version. ' INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:32Z INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:32Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpeaonpwe8/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmpeaonpwe8/model.ckpt-50 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 2.42050s INFO:tensorflow:Inference Time : 2.42050s INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:35 INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:35 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.515625, global_step = 50, loss = 0.6688157 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.515625, global_step = 50, loss = 0.6688157 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpeaonpwe8/model.ckpt-50 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpeaonpwe8/model.ckpt-50 {'accuracy': 0.515625, 'loss': 0.6688157, 'global_step': 50}
詳細については、 tf.keras.estimator.model_to_estimator
のドキュメントを参照してください。
Estimatorを使用したオブジェクトベースのチェックポイントの保存
デフォルトでは、Estimatorは、チェックポイントガイドで説明されているオブジェクトグラフではなく、変数名でチェックポイントを保存します。 tf.train.Checkpoint
は名前ベースのチェックポイントを読み取りますが、モデルの一部をEstimatorのmodel_fn
外に移動すると、変数名が変更される場合があります。上位互換性のために、オブジェクトベースのチェックポイントを保存すると、Estimator内でモデルをトレーニングし、モデルの外で使用することが容易になります。
import tensorflow.compat.v1 as tf_compat
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
def model_fn(features, labels, mode):
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
optimizer=opt, net=net)
with tf.GradientTape() as tape:
output = net(features['x'])
loss = tf.reduce_mean(tf.abs(output - features['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
return tf.estimator.EstimatorSpec(
mode,
loss=loss,
train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
ckpt.step.assign_add(1)),
# Tell the Estimator to save "ckpt" in an object-based format.
scaffold=tf_compat.train.Scaffold(saver=ckpt))
tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 4.4040537, step = 0 INFO:tensorflow:loss = 4.4040537, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 35.247967. INFO:tensorflow:Loss for final step: 35.247967. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f3d64534518>
tf.train.Checkpoint
は、 model_dir
からEstimatorのチェックポイントをロードできます。
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy() # From est.train(..., steps=10)
10
EstimatorからのSavedModels
Estimatorは、tf.Estimator.export_saved_modelを介してtf.Estimator.export_saved_model
エクスポートしtf.Estimator.export_saved_model
。
input_column = tf.feature_column.numeric_column("x")
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])
def input_fn():
return tf.data.Dataset.from_tensor_slices(
({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpczwhe6jk WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpczwhe6jk INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpczwhe6jk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpczwhe6jk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.48830828. INFO:tensorflow:Loss for final step: 0.48830828. <tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f3d6452eb00>
Estimator
を保存するには、 serving_input_receiver
を作成する必要があります。この関数は、 tf.Graph
によって受信された生データを解析するtf.Graph
一部を構築します。
tf.estimator.export
モジュールには、これらのreceivers
構築に役立つ関数が含まれています。
次のコードは、 feature_columns
に基づいて、シリアル化されたtf.Example
プロトコルバッファを受け入れるレシーバーを構築します。これは、 tf-servingでよく使用されます。
tmpdir = tempfile.mkdtemp()
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
tf.feature_column.make_parse_example_spec([input_column]))
estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Restoring parameters from /tmp/tmpczwhe6jk/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmpczwhe6jk/model.ckpt-50 INFO:tensorflow:Assets added to graph. INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: /tmp/tmp16t8uhub/from_estimator/temp-1610074656/saved_model.pb INFO:tensorflow:SavedModel written to: /tmp/tmp16t8uhub/from_estimator/temp-1610074656/saved_model.pb
Pythonからそのモデルをロードして実行することもできます。
imported = tf.saved_model.load(estimator_path)
def predict(x):
example = tf.train.Example()
example.features.feature["x"].float_list.value.extend([x])
return imported.signatures["predict"](
examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.581246]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.32789052]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.418754, 0.581246]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>} {'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.24376468]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1321492]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7562353 , 0.24376468]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}
tf.estimator.export.build_raw_serving_input_receiver_fn
使用すると、 tf.train.Example
ではなく生のテンソルを受け取る入力関数を作成できます。
Estimatorでのtf.distribute.Strategy
使用(限定サポート)
tf.estimator
は、元々非同期パラメーターサーバーアプローチをサポートしていた分散トレーニングTensorFlowAPIです。 tf.estimator
サポートするようにtf.distribute.Strategy
。 tf.estimator
を使用している場合は、コードをほとんど変更せずに分散トレーニングに変更できます。これにより、Estimatorユーザーは、TPUを使用するだけでなく、複数のGPUと複数のワーカーで同期分散トレーニングを実行できるようになりました。ただし、Estimatorでのこのサポートは制限されています。詳細については、以下の「現在サポートされているもの」セクションを確認してください。
Estimatorでtf.distribute.Strategy
を使用することは、 tf.distribute.Strategy
とは少し異なります。 strategy.scope
を使用する代わりに、ストラテジーオブジェクトをEstimatorのRunConfig
ます。
詳細については、配布されたトレーニングガイドを参照してください。
これは、 LinearRegressor
れたEstimator LinearRegressor
とMirroredStrategy
を使用してこれを示すコードスニペットです。
mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
feature_columns=[tf.feature_column.numeric_column('feats')],
optimizer='SGD',
config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Not using Distribute Coordinator. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp4uihzu_a WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp4uihzu_a INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4uihzu_a', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4uihzu_a', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
ここでは、事前に作成されたEstimatorを使用しますが、同じコードがカスタムEstimatorでも機能します。 train_distribute
はトレーニングの配布方法を決定し、 eval_distribute
は評価の配布方法を決定します。これは、トレーニングと評価の両方に同じ戦略を使用するKerasとのもう1つの違いです。
これで、入力関数を使用してこのEstimatorをトレーニングおよび評価できます。
def input_fn():
dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:41Z INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:41Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmp4uihzu_a/model.ckpt-10 INFO:tensorflow:Restoring parameters from /tmp/tmp4uihzu_a/model.ckpt-10 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.26266s INFO:tensorflow:Inference Time : 0.26266s INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:42 INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:42 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmp4uihzu_a/model.ckpt-10 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmp4uihzu_a/model.ckpt-10 {'average_loss': 1.4210855e-14, 'label/mean': 1.0, 'loss': 1.4210855e-14, 'prediction/mean': 0.99999994, 'global_step': 10}
ここでEstimatorとKerasのもう1つの違いは、入力処理です。 Kerasでは、データセットの各バッチが複数のレプリカに自動的に分割されます。ただし、Estimatorでは、自動バッチ分割を実行したり、異なるワーカー間でデータを自動的にシャーディングしたりすることはありません。データをワーカーとデバイスに分散する方法を完全に制御できます。また、データを分散する方法を指定するためにinput_fn
を指定する必要があります。
input_fn
はワーカーごとに1回呼び出されるため、ワーカーごとに1つのデータセットが提供されます。次に、そのデータセットの1つのバッチがそのワーカーの1つのレプリカに供給されるため、1つのワーカーのN個のレプリカに対してN個のバッチが消費されます。つまり、 input_fn
によって返されるデータセットは、サイズPER_REPLICA_BATCH_SIZE
バッチを提供する必要があります。また、ステップのグローバルバッチサイズは、 PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
として取得できます。
マルチワーカートレーニングを実行するときは、データをワーカー間で分割するか、それぞれにランダムシードを使用してシャッフルする必要があります。これを行う方法の例は、Estimatorを使用したマルチワーカートレーニングチュートリアルで確認できます。
同様に、マルチワーカーおよびパラメーターサーバー戦略も使用できます。コードは同じままですが、 tf.estimator.train_and_evaluate
を使用して、クラスターで実行されているバイナリごとにTF_CONFIG
環境変数を設定する必要があります。
現在サポートされているものは何ですか?
TPUStrategy
を除くすべての戦略を使用したEstimatorでのトレーニングのサポートは限られていTPUStrategy
。基本的なトレーニングと評価は機能するはずですが、 v1.train.Scaffold
などの多くの高度な機能は機能しません。この統合には多くのバグがある可能性があり、このサポートを積極的に改善する予定はありません(Kerasとカスタムトレーニングループのサポートに焦点が当てられています)。可能であれば、代わりにこれらのAPIでtf.distribute
を使用することをおtf.distribute
します。
トレーニングAPI | MirroredStrategy | TPUStrategy | MultiWorkerMirroredStrategy | CentralStorageStrategy | ParameterServerStrategy |
---|---|---|---|---|---|
Estimator API | 限定的なサポート | サポートされていません | 限定的なサポート | 限定的なサポート | 限定的なサポート |
例とチュートリアル
Estimatorでさまざまな戦略を使用する方法を示すエンドツーエンドの例を次に示します。
- Estimatorを使用したマルチワーカートレーニングチュートリアルでは、MNISTデータセットで
MultiWorkerMirroredStrategy
を使用して複数のワーカーでトレーニングする方法を示します。 - Kubernetesテンプレートを使用して、
tensorflow/ecosystem
で分散戦略を使用してマルチワーカートレーニングを実行するエンドツーエンドの例。tf.keras.estimator.model_to_estimator
モデルから開始し、tf.keras.estimator.model_to_estimator
を使用してEstimatorに変換します。 -
MirroredStrategy
またはMultiWorkerMirroredStrategy
いずれかを使用してトレーニングできる公式のResNet50モデル。