このページは Cloud Translation API によって翻訳されました。
Switch to English

推定者

TensorFlow.orgで見る Google Colabで実行 GitHubでソースを表示する ノートブックをダウンロード

このドキュメントの紹介tf.estimator -a高レベルTensorFlowのAPI。エスティメータは次のアクションをカプセル化します。

  • トレーニング
  • 評価
  • 予測
  • サービングのためのエクスポート

私たちが提供する既成の見積もりを使用するか、独自のカスタム見積もりを作成することができます。すべてのEstimator(既製でもカスタムでも)は、 tf.estimator.Estimatorクラスに基づくクラスです。

簡単な例として、 Estimatorチュートリアルを試してください 。 API設計の概要については、 ホワイトペーパーを参照してください。

メリット

tf.keras.Modelと同様に、 estimatorはモデルレベルの抽象化です。 tf.estimatorは、現在tf.kerasために現在開発中のいくつかの機能を提供します。これらは:

  • パラメータサーバーベースのトレーニング
  • TFXの完全な統合。

見積もり機能

見積もりには次の利点があります。

  • モデルを変更せずに、ローカルホストまたは分散マルチサーバー環境でEstimatorベースのモデルを実行できます。さらに、モデルを再コーディングせずに、CPU、GPU、またはTPUでEstimatorベースのモデルを実行できます。
  • エスティメータは、次の方法とタイミングを制御する安全な分散トレーニングループを提供します。
    • データを読み込む
    • 例外を処理する
    • チェックポイントファイルを作成し、障害から回復する
    • TensorBoardの概要を保存する

Estimatorを使用してアプリケーションを作成する場合は、データ入力パイプラインをモデルから分離する必要があります。この分離により、さまざまなデータセットでの実験が簡素化されます。

既成の見積もりを使用する

事前に作成されたEstimatorを使用すると、ベースのTensorFlow APIよりもはるかに高い概念レベルで作業できます。 Estimatorがすべての「配管」を処理するため、計算グラフやセッションの作成について心配する必要がなくなります。さらに、事前に作成されたEstimatorを使用すると、最小限のコード変更だけでさまざまなモデルアーキテクチャを試すことができます。たとえば、 tf.estimator.DNNClassifierは、密なフィードフォワードニューラルネットワークに基づいて分類モデルをトレーニングする既成のEstimatorクラスです。

事前に作成されたEstimatorに依存するTensorFlowプログラムは、通常、次の4つのステップで構成されます。

1.入力関数を書く

たとえば、トレーニングセットをインポートする関数とテストセットをインポートする関数を作成できます。推定器は、入力がオブジェクトのペアとしてフォーマットされることを期待します。

  • キーが特徴名で値が対応する特徴データを含むTensors(またはSparseTensors)である辞書
  • 1つ以上のラベルを含むTensor

input_fnは、その形式でペアを生成するtf.data.Datasetを返す必要があります。

たとえば、次のコードは、Titanicデータセットのtrain.csvファイルからtf.data.Datasetを構築します。

import tensorflow as tf

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.experimental.AUTOTUNE))
  return titanic_batches

input_fn実行されるtf.Graphと直接返すことができる(features_dics, labels)グラフテンソルを含む対を、これは復帰定数のような単純なケースのエラープローン外です。

2.特徴列を定義します。

tf.feature_columnは、機能名、そのタイプ、および入力前処理を識別します。

たとえば、次のスニペットは3つの機能列を作成します。

  • 1つ目は、 age機能を直接浮動小数点入力として使用します。
  • 2つ目は、 class特徴をカテゴリカル入力として使用します。
  • 3つ目はembark_townをカテゴリ入力として使用しますが、 hashing trickを使用して、オプションを列挙する必要をembark_town 、オプションの数を設定します。

詳細については、 機能列のチュートリアルを参照してください

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3.関連する既製のEstimatorをインスタンス化します。

たとえば、次に示すのは、 LinearClassifierという名前のLinearClassifier Estimatorのインスタンス化のサンプルLinearClassifier

import tempfile
model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpbsi4iylb', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

詳細については、 線形分類子チュートリアルを参照してください

4.トレーニング、評価、または推論メソッドを呼び出します。

すべてのEstimatorは、 trainevaluate 、およびpredict方法を提供します。

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1481: Layer.add_variable (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:112: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpbsi4iylb/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpbsi4iylb/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.6262238.

result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-09-19T01:23:19Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpbsi4iylb/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.60420s
INFO:tensorflow:Finished evaluation at 2020-09-19-01:23:20
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.709375, accuracy_baseline = 0.634375, auc = 0.7508315, auc_precision_recall = 0.6325826, average_loss = 0.5779631, global_step = 100, label/mean = 0.365625, loss = 0.5779631, precision = 0.77272725, prediction/mean = 0.30943406, recall = 0.2905983
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpbsi4iylb/model.ckpt-100
accuracy : 0.709375
accuracy_baseline : 0.634375
auc : 0.7508315
auc_precision_recall : 0.6325826
average_loss : 0.5779631
label/mean : 0.365625
loss : 0.5779631
precision : 0.77272725
prediction/mean : 0.30943406
recall : 0.2905983
global_step : 100

for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpbsi4iylb/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-0.5147792]
logistic : [0.37407383]
probabilities : [0.62592614 0.37407383]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

既製の見積もりの​​利点

事前に作成されたEstimatorはベストプラクティスをエンコードし、次の利点を提供します。

  • 計算グラフのさまざまな部分を実行する場所を決定し、単一のマシンまたはクラスターに戦略を実装するためのベストプラクティス。
  • イベント(概要)を作成するためのベストプラクティスと一般的に役立つ概要。

既成のEstimatorを使用しない場合は、前述の機能を自分で実装する必要があります。

カスタム推定量

すべてのEstimatorの中心は、既製でもカスタムでも、そのモデル関数です 。これは、トレーニング、評価、予測のためのグラフを作成する方法です。既製のEstimatorを使用している場合、他の誰かがすでにモデル関数を実装しています。カスタムEstimatorを使用する場合は、モデル関数を自分で作成する必要があります。

したがって、推奨されるワークフローは次のとおりです。

  1. 適切な既成のEstimatorが存在すると仮定して、それを使用して最初のモデルを構築し、その結果を使用してベースラインを確立します。
  2. この既成のEstimatorを使用して、データの整合性と信頼性を含む、全体的なパイプラインを構築およびテストします。
  3. 適切な代替の事前に作成されたEstimatorが利用可能な場合は、実験を実行して、どの事前作成されたEstimatorが最良の結果を生成するかを決定します。
  4. おそらく、独自のカスタムEstimatorを作成して、モデルをさらに改善します。

Kerasモデルから推定器を作成する

tf.keras.estimator.model_to_estimatorを使用して、既存のtf.keras.estimator.model_to_estimatorモデルをEstimatorに変換できます。そうすることで、Kerasモデルが分散トレーニングなどのEstimatorの強みにアクセスできるようになります。

Keras MobileNet V2モデルをインスタンス化し、トレーニングするオプティマイザ、損失、メトリックを使用してモデルをコンパイルします。

import tensorflow as tf
import tensorflow_datasets as tfds
keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

コンパイルされたKerasモデルからEstimatorを作成します。 Kerasモデルの初期モデル状態は、作成されたEstimator保持されます。

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp6jubajxm
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py:220: set_learning_phase (from tensorflow.python.keras.backend) is deprecated and will be removed after 2020-10-11.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp6jubajxm', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

他のEstimatorと同様に、派生Estimatorを扱います。

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

トレーニングするには、Estimatorのtrain関数を呼び出します。

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=500)
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

Warning:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incomplete1VDT6U/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp6jubajxm/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp6jubajxm/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting from: /tmp/tmp6jubajxm/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting from: /tmp/tmp6jubajxm/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp6jubajxm/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp6jubajxm/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 0.69570315, step = 0

INFO:tensorflow:loss = 0.69570315, step = 0

INFO:tensorflow:global_step/sec: 23.7702

INFO:tensorflow:global_step/sec: 23.7702

INFO:tensorflow:loss = 0.6749977, step = 100 (4.209 sec)

INFO:tensorflow:loss = 0.6749977, step = 100 (4.209 sec)

INFO:tensorflow:global_step/sec: 26.3102

INFO:tensorflow:global_step/sec: 26.3102

INFO:tensorflow:loss = 0.69422066, step = 200 (3.800 sec)

INFO:tensorflow:loss = 0.69422066, step = 200 (3.800 sec)

INFO:tensorflow:global_step/sec: 25.8025

INFO:tensorflow:global_step/sec: 25.8025

INFO:tensorflow:loss = 0.5480626, step = 300 (3.876 sec)

INFO:tensorflow:loss = 0.5480626, step = 300 (3.876 sec)

INFO:tensorflow:global_step/sec: 25.8782

INFO:tensorflow:global_step/sec: 25.8782

INFO:tensorflow:loss = 0.6515007, step = 400 (3.865 sec)

INFO:tensorflow:loss = 0.6515007, step = 400 (3.865 sec)

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...

INFO:tensorflow:Saving checkpoints for 500 into /tmp/tmp6jubajxm/model.ckpt.

INFO:tensorflow:Saving checkpoints for 500 into /tmp/tmp6jubajxm/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...

INFO:tensorflow:Loss for final step: 0.7829801.

INFO:tensorflow:Loss for final step: 0.7829801.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f87c0520518>

同様に、評価するには、Estimatorのevaluate関数を呼び出します。

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Starting evaluation at 2020-09-19T01:24:35Z

INFO:tensorflow:Starting evaluation at 2020-09-19T01:24:35Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from /tmp/tmp6jubajxm/model.ckpt-500

INFO:tensorflow:Restoring parameters from /tmp/tmp6jubajxm/model.ckpt-500

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Inference Time : 1.74243s

INFO:tensorflow:Inference Time : 1.74243s

INFO:tensorflow:Finished evaluation at 2020-09-19-01:24:37

INFO:tensorflow:Finished evaluation at 2020-09-19-01:24:37

INFO:tensorflow:Saving dict for global step 500: accuracy = 0.61875, global_step = 500, loss = 0.6329565

INFO:tensorflow:Saving dict for global step 500: accuracy = 0.61875, global_step = 500, loss = 0.6329565

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmp/tmp6jubajxm/model.ckpt-500

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmp/tmp6jubajxm/model.ckpt-500

{'accuracy': 0.61875, 'loss': 0.6329565, 'global_step': 500}

詳細については、 tf.keras.estimator.model_to_estimatorのドキュメントを参照してください。