
Estimators


This document introduces tf.estimator, a high-level TensorFlow API. Estimators encapsulate the following actions:

  • training
  • evaluation
  • prediction
  • export for serving

TensorFlow implements several pre-made Estimators. Custom Estimators are still supported, but mainly as a backwards-compatibility measure. Custom Estimators should not be used for new code. All Estimators, whether pre-made or custom, are classes based on the tf.estimator.Estimator class.

For a quick example, try the Estimator tutorials. For an overview of the API design, check the white paper.

Setup

 pip install -q -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

Advantages

Similar to tf.keras.Model, an Estimator is a model-level abstraction. tf.estimator provides some capabilities that are still under development for tf.keras. These are:

  • Parameter server-based training
  • Full TFX integration.

Estimator capabilities

Estimators provide the following benefits:

  • You can run Estimator-based models on a local host or in a distributed multi-server environment without changing your model. Furthermore, you can run Estimator-based models on CPUs, GPUs, or TPUs without recoding your model.
  • Estimators provide a safe distributed training loop that controls how and when to:
    • load data
    • handle exceptions
    • create checkpoint files and recover from failures
    • save summaries for TensorBoard

When writing an application with Estimators, you must separate the data input pipeline from the model. This separation simplifies experiments with different datasets.
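
For instance, a minimal sketch of this separation might look like the following (the CSV file names and the "label" column are hypothetical placeholders, not part of this guide):

def make_input_fn(csv_path, shuffle=True):
  # Each dataset gets its own input function; the model below never changes.
  def input_fn():
    dataset = tf.data.experimental.make_csv_dataset(
        csv_path, batch_size=32, label_name="label")
    if shuffle:
      dataset = dataset.shuffle(500)
    return dataset
  return input_fn

sketch_columns = [tf.feature_column.numeric_column("x")]
sketch_estimator = tf.estimator.LinearClassifier(feature_columns=sketch_columns)
sketch_estimator.train(make_input_fn("train.csv"), steps=100)
sketch_estimator.evaluate(make_input_fn("eval.csv", shuffle=False), steps=10)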

Using pre-made Estimators

Pre-made Estimators enable you to work at a much higher conceptual level than the base TensorFlow APIs. You no longer have to worry about creating the computational graph or sessions since Estimators handle all the "plumbing" for you. Furthermore, pre-made Estimators let you experiment with different model architectures by making only minimal code changes. tf.estimator.DNNClassifier, for example, is a pre-made Estimator class that trains classification models based on dense, feed-forward neural networks.
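
As an illustrative sketch (the feature column and layer sizes below are arbitrary placeholders, not values used elsewhere in this guide), instantiating such a pre-made Estimator takes only a few lines:

# Hypothetical DNNClassifier setup: one numeric feature and two hidden layers.
dnn_age = tf.feature_column.numeric_column('age')
dnn_classifier = tf.estimator.DNNClassifier(
    feature_columns=[dnn_age],
    hidden_units=[32, 16],
    n_classes=2)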

A TensorFlow program relying on a pre-made Estimator typically consists of the following four steps:

1. Write an input function

For example, you might create one function to import the training set and another function to import the test set. Estimators expect their inputs to be formatted as a pair of objects:

  • A dictionary in which the keys are feature names and the values are Tensors (or SparseTensors) containing the corresponding feature data
  • A Tensor containing one or more labels

The input_fn should return a tf.data.Dataset that yields pairs in that format.

For example, the following code builds a tf.data.Dataset from the Titanic dataset's train.csv file:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.experimental.AUTOTUNE))
  return titanic_batches

Your input_fn is executed in a tf.Graph and can also directly return a (features_dict, labels) pair containing graph tensors, but this is error-prone outside of simple cases like returning constants.
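
For illustration only, a direct-return input_fn of that simple kind, assuming a single constant feature, could look like this:

def constant_input_fn():
  # Returns a (features, labels) pair of graph tensors directly
  # instead of a tf.data.Dataset; fine for constants, error-prone otherwise.
  features = {"age": tf.constant([[30.], [40.]])}
  labels = tf.constant([[1], [0]])
  return features, labels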

2. Define the feature columns.

Each tf.feature_column identifies a feature name, its type, and any input pre-processing.

For example, the following snippet creates three feature columns.

  • The first uses the age feature directly as a floating-point input.
  • The second uses the class feature as a categorical input.
  • The third uses embark_town as a categorical input, but uses the hashing trick to avoid the need to enumerate the options, and to set the number of options.

For further information, see the feature columns tutorial.

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. Instantiate the relevant pre-made Estimator.

For example, here's a sample instantiation of a pre-made Estimator named LinearClassifier:

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp_2fgw1gd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

For more information, see the linear classifier tutorial.

4. Call a training, evaluation, or inference method.

All Estimators provide train, evaluate, and predict methods.

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1481: Layer.add_variable (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:112: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp_2fgw1gd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmp_2fgw1gd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.6098593.

result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-10-15T01:25:18Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp_2fgw1gd/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.63935s
INFO:tensorflow:Finished evaluation at 2020-10-15-01:25:19
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.7, accuracy_baseline = 0.603125, auc = 0.70968133, auc_precision_recall = 0.6162292, average_loss = 0.6068252, global_step = 100, label/mean = 0.396875, loss = 0.6068252, precision = 0.6962025, prediction/mean = 0.3867289, recall = 0.43307087
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmp_2fgw1gd/model.ckpt-100
accuracy : 0.7
accuracy_baseline : 0.603125
auc : 0.70968133
auc_precision_recall : 0.6162292
average_loss : 0.6068252
label/mean : 0.396875
loss : 0.6068252
precision : 0.6962025
prediction/mean : 0.3867289
recall : 0.43307087
global_step : 100

for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp_2fgw1gd/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [0.6824188]
logistic : [0.6642783]
probabilities : [0.33572164 0.6642783 ]
class_ids : [1]
classes : [b'1']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

Benefits of pre-made Estimators

Pre-made Estimators encode best practices, providing the following benefits:

  • Best practices for determining where different parts of the computational graph should run, implementing strategies on a single machine or on a cluster.
  • Best practices for event (summary) writing and universally useful summaries.

If you don't use pre-made Estimators, you must implement the preceding features yourself.

Custom Estimators

The heart of every Estimator, whether pre-made or custom, is its model function, model_fn, which is a method that builds graphs for training, evaluation, and prediction. When you are using a pre-made Estimator, someone else has already implemented the model function. When relying on a custom Estimator, you must write the model function yourself.
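
Schematically, and only as a sketch (the tiny linear model inside is a placeholder), a custom model_fn follows the signature below; a fuller model_fn example appears later in this guide:

def my_model_fn(features, labels, mode):
  # Build the graph: a single dense layer stands in for a real model.
  logits = tf.keras.layers.Dense(1)(features["x"])
  loss = tf.reduce_mean(tf.square(tf.cast(labels, tf.float32) - logits))

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode, predictions={"logits": logits})

  optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.01)
  train_op = optimizer.minimize(
      loss, global_step=tf.compat.v1.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)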

Create an Estimator from a Keras model

You can convert existing Keras models to Estimators with tf.keras.estimator.model_to_estimator. This is helpful if you want to modernize your model code, but your training pipeline still requires Estimators.

Instantiate a Keras MobileNet V2 model and compile the model with the optimizer, loss, and metrics to train with:

import tensorflow as tf
import tensorflow_datasets as tfds

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Create an Estimator from the compiled Keras model. The initial model state of the Keras model is preserved in the created Estimator:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpy7tafocp
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py:220: set_learning_phase (from tensorflow.python.keras.backend) is deprecated and will be removed after 2020-10-11.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpy7tafocp', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Treat the derived Estimator as you would any other Estimator.

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

To train, call the Estimator's train function:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

WARNING:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteMUGW8X/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpy7tafocp/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmpy7tafocp/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpy7tafocp/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.68286127, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpy7tafocp/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.70231926.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f1a6c4d4cf8>

Similarly, to evaluate, call the Estimator's evaluate function:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:26Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpy7tafocp/model.ckpt-50
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 1.92025s
INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:28
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.565625, global_step = 50, loss = 0.6713216
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpy7tafocp/model.ckpt-50

{'accuracy': 0.565625, 'loss': 0.6713216, 'global_step': 50}

For more details, please refer to the documentation for tf.keras.estimator.model_to_estimator.

Saving object-based checkpoints with Estimator

Estimators by default save checkpoints with variable names rather than the object graph described in the Checkpoint guide. tf.train.Checkpoint will read name-based checkpoints, but variable names may change when moving parts of a model outside of the Estimator's model_fn. For forwards compatibility, saving object-based checkpoints makes it easier to train a model inside an Estimator and then use it outside of one.

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.5633583, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 37.95615.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f1a6c477630>

tf.train.Checkpoint can then load the Estimator's checkpoints from its model_dir.

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

SavedModels from Estimators

Estimators export SavedModels through tf.Estimator.export_saved_model.

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpn_8rzqza
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpn_8rzqza', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpn_8rzqza/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpn_8rzqza/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.35876164.

<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f1a6c448b00>

To save an Estimator you need to create a serving_input_receiver. This function builds a part of a tf.Graph that parses the raw data received by the SavedModel.

The tf.estimator.export module contains functions to help build these receivers.

The following code builds a receiver, based on the feature_columns, that accepts serialized tf.Example protocol buffers, which are often used with tf-serving.

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmp/tmpn_8rzqza/model.ckpt-50
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /tmp/tmptcppevt7/from_estimator/temp-1602725189/saved_model.pb

You can also load and run that model, from Python:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.43590346, 0.5640965 ]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2578045]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.5640965]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>}
{'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7984398 , 0.20156018]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.3765715]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2015602]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}

tf.estimator.export.build_raw_serving_input_receiver_fn allows you to create input functions which take raw tensors rather than tf.train.Examples.
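
For illustration, here is a hand-rolled equivalent, assuming the same single float feature "x" as above; build_raw_serving_input_receiver_fn wraps essentially the same pattern of feeding raw tensors:

def raw_serving_input_fn():
  # Called inside the export graph, so a v1 placeholder is valid here;
  # callers feed a raw float tensor instead of a serialized tf.Example.
  x = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 1], name="x")
  features = {"x": x}
  return tf.estimator.export.ServingInputReceiver(features, features)

raw_estimator_path = estimator.export_saved_model(
    os.path.join(tmpdir, 'from_estimator_raw'), raw_serving_input_fn)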

Using tf.distribute.Strategy with Estimator (limited support)

See the Distributed training guide for more info.

tf.estimator is a distributed training TensorFlow API that originally supported the asynchronous parameter server approach. tf.estimator now supports tf.distribute.Strategy. If you're using tf.estimator, you can change to distributed training with very few changes to your code. With this, Estimator users can now do synchronous distributed training on multiple GPUs and multiple workers, as well as use TPUs. This support in Estimator is, however, limited. See the What's supported now section below for more details.

Using tf.distribute.Strategy with Estimator is slightly different from the Keras case. Instead of using strategy.scope, you pass the strategy object into the RunConfig for the Estimator.

Here is a snippet of code that shows this with a pre-made Estimator LinearRegressor and MirroredStrategy:

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmphb70j0wf
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmphb70j0wf', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f1b28cfd4e0>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

We use a pre-made Estimator here, but the same code works with a custom Estimator as well. train_distribute determines how training will be distributed, and eval_distribute determines how evaluation will be distributed. This is another difference from Keras, where the same strategy is used for both training and evaluation.

Now we can train and evaluate this Estimator with an input function:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:339: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1b28ec8b70> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmphb70j0wf/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmphb70j0wf/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:AutoGraph could not transform <function _combine_distributed_scaffold.<locals>.<lambda> at 0x7f1ad8062d08> and will run it as-is.
Cause: could not parse the source code:

      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
INFO:tensorflow:Starting evaluation at 2020-10-15T01:26:34Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmphb70j0wf/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.23888s
INFO:tensorflow:Finished evaluation at 2020-10-15-01:26:34
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmphb70j0wf/model.ckpt-10

{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

Another difference to highlight here between Estimator and Keras is the input handling. In Keras, each batch of the dataset is split automatically across the multiple replicas. In Estimator, however, there is no automatic batch splitting, nor automatic sharding of the data across different workers. You have full control over how you want your data to be distributed across workers and devices, and you must provide an input_fn to specify how to distribute your data.

Your input_fn is called once per worker, thus giving one dataset per worker. Then one batch from that dataset is fed to one replica on that worker, thereby consuming N batches for N replicas on one worker. In other words, the dataset returned by the input_fn should provide batches of size PER_REPLICA_BATCH_SIZE. The global batch size for a step can then be obtained as PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync.
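
Concretely, a sketch of an input_fn that respects this contract (the batch size here is just a hypothetical choice):

PER_REPLICA_BATCH_SIZE = 16

def per_replica_input_fn():
  # Batches are sized per replica; the effective global batch size is
  # PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync.
  dataset = tf.data.Dataset.from_tensors(({"feats": [1.]}, [1.]))
  return dataset.repeat(1000).batch(PER_REPLICA_BATCH_SIZE)

global_batch_size = (
    PER_REPLICA_BATCH_SIZE * mirrored_strategy.num_replicas_in_sync)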

When doing multi-worker training, you should either split your data across the workers, or shuffle with a random seed on each. You can see an example of how to do this in Multi-worker training with Estimator.
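
For example, per-worker sharding with tf.data could be sketched as follows (the worker count and index are hypothetical and would normally come from the cluster configuration):

NUM_WORKERS = 2   # hypothetical cluster size
WORKER_INDEX = 0  # hypothetical index of this worker

def sharded_input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats": [1.]}, [1.]))
  dataset = dataset.repeat(1000)
  # Each worker reads a disjoint shard of the data.
  return dataset.shard(NUM_WORKERS, WORKER_INDEX).batch(16)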

Similarly, you can use multi-worker and parameter server strategies as well. The code remains the same, but you need to use tf.estimator.train_and_evaluate and set the TF_CONFIG environment variable for each binary running in your cluster.
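
Schematically, and with a purely hypothetical two-worker cluster, that setup looks something like this on each worker binary:

import json

# Hypothetical cluster description; each binary sets its own task index.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["host1:12345", "host2:23456"]},
    "task": {"type": "worker", "index": 0}
})

train_spec = tf.estimator.TrainSpec(input_fn=input_fn, max_steps=100)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn, steps=10)
tf.estimator.train_and_evaluate(regressor, train_spec, eval_spec)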

What's supported now?

There is limited support for training with Estimator using all strategies except TPUStrategy. Basic training and evaluation should work, but a number of advanced features such as v1.train.Scaffold do not. There may also be a number of bugs in this integration. At this time, there are no plans to actively improve this support; the focus is instead on Keras and custom training loop support. If at all possible, you should prefer to use tf.distribute with those APIs.

Training API  | MirroredStrategy | TPUStrategy   | MultiWorkerMirroredStrategy | CentralStorageStrategy | ParameterServerStrategy
Estimator API | Limited support  | Not supported | Limited support             | Limited support        | Limited support

Examples and tutorials

Here are some examples that show end-to-end usage of various strategies with Estimator:

  1. Multi-worker training with Estimator trains MNIST with multiple workers using MultiWorkerMirroredStrategy.
  2. An end-to-end example of multi-worker training in tensorflow/ecosystem using Kubernetes templates. This example starts with a Keras model and converts it to an Estimator using the tf.keras.estimator.model_to_estimator API.
  3. The official ResNet50 model, which can be trained using either MirroredStrategy or MultiWorkerMirroredStrategy.