Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Estymatory

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

Ten dokument wprowadza tf.estimator - interfejs API TensorFlow wysokiego poziomu. Estymatory obejmują następujące działania:

  • trening
  • ocena
  • Prognoza
  • eksport do serwowania

Możesz skorzystać z gotowych estymatorów, które udostępniamy, lub napisać własne, niestandardowe estymatory. Wszystkie estymatory - gotowe lub niestandardowe - są klasami opartymi na klasie tf.estimator.Estimator .

Aby uzyskać szybki przykład, wypróbuj samouczki Estimatora . Omówienie projektu interfejsu API znajduje się w białej księdze .

Ustawiać

 pip install -q -U tensorflow_datasets
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

Zalety

Podobnie jak tf.keras.Model , estimator jest abstrakcją na poziomie modelu. tf.estimator zapewnia pewne możliwości, które są obecnie nadal rozwijane dla tf.keras . To są:

  • Trening oparty na serwerze parametrów
  • Pełna integracja z TFX .

Możliwości estymatorów

Estymatory zapewniają następujące korzyści:

  • Modele oparte na estymatorze można uruchamiać na hoście lokalnym lub w rozproszonym środowisku wieloserwerowym bez zmiany modelu. Ponadto możesz uruchamiać modele oparte na estymatorze na procesorach, procesorach graficznych lub TPU bez ponownego kodowania modelu.
  • Estymatory zapewniają bezpieczną, rozproszoną pętlę szkoleniową, która kontroluje, jak i kiedy:
    • ładowanie danych
    • obsługi wyjątków
    • tworzyć pliki punktów kontrolnych i odzyskiwać po awariach
    • zapisz podsumowania dla TensorBoard

Podczas pisania aplikacji za pomocą estymatorów należy oddzielić potok wprowadzania danych od modelu. To rozdzielenie upraszcza eksperymenty z różnymi zestawami danych.

Korzystanie z gotowych estymatorów

Gotowe estymatory umożliwiają pracę na znacznie wyższym poziomie koncepcyjnym niż podstawowe interfejsy API TensorFlow. Nie musisz już martwić się o tworzenie wykresów obliczeniowych lub sesji, ponieważ estymatory zajmują się całą „hydrauliką” za Ciebie. Ponadto gotowe estymatory pozwalają eksperymentować z różnymi architekturami modeli, wprowadzając jedynie minimalne zmiany w kodzie. tf.estimator.DNNClassifier przykład tf.estimator.DNNClassifier , która szkoli modele klasyfikacji oparte na gęstych, sprzężonych sieciach neuronowych.

Program TensorFlow oparty na gotowym estymatorze składa się zwykle z następujących czterech kroków:

1. Napisz funkcje wejściowe

Na przykład można utworzyć jedną funkcję do importowania zestawu uczącego i inną funkcję do importowania zestawu testowego. Estymatorzy oczekują, że ich dane wejściowe zostaną sformatowane jako para obiektów:

  • Słownik, w którym klucze są nazwami funkcji, a wartościami są tensory (lub SparseTensors), zawierające odpowiednie dane funkcji
  • Tensor zawierający jedną lub więcej etykiet

input_fn powinien zwrócić tf.data.Dataset który zwraca pary w tym formacie.

Na przykład, następujący kod buduje tf.data.Dataset z Titanica danej jednostki train.csv pliku:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.experimental.AUTOTUNE))
  return titanic_batches

input_fn jest wykonywany w tf.Graph i może również bezpośrednio zwracać parę (features_dics, labels) zawierającą tensory wykresu, ale jest to podatne na błędy poza prostymi przypadkami, takimi jak zwracanie stałych.

2. Zdefiniuj kolumny funkcji.

Każda tf.feature_column identyfikuje nazwę funkcji, jej typ i wszelkie wstępne przetwarzanie danych wejściowych.

Na przykład poniższy fragment kodu tworzy trzy kolumny funkcji.

  • Pierwsza wykorzystuje funkcję age bezpośrednio jako zmiennoprzecinkowe dane wejściowe.
  • Drugi wykorzystuje cechę class jako kategoryczne dane wejściowe.
  • Trzeci używa embark_town jako kategorycznego wejścia, ale używa hashing trick aby uniknąć konieczności wyliczania opcji i ustawiania liczby opcji.

Aby uzyskać więcej informacji, zobacz samouczek dotyczący kolumn elementowych .

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. Utwórz instancję odpowiedniego, przygotowanego wcześniej Estymatora.

Na przykład, oto przykładowa instancja LinearClassifier estymatora o nazwie LinearClassifier :

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpjm3x59ce', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Aby uzyskać więcej informacji, zobacz samouczek dotyczący klasyfikatora liniowego .

4. Nazwij metodę szkolenia, oceny lub wnioskowania.

Wszystkie estymatory zapewniają metody train , evaluate i predict .

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1481: Layer.add_variable (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:112: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpjm3x59ce/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpjm3x59ce/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.5892383.

result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-09-23T01:21:41Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpjm3x59ce/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.74020s
INFO:tensorflow:Finished evaluation at 2020-09-23-01:21:42
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.6875, accuracy_baseline = 0.609375, auc = 0.73963076, auc_precision_recall = 0.64400905, average_loss = 0.59503603, global_step = 100, label/mean = 0.390625, loss = 0.59503603, precision = 0.74509805, prediction/mean = 0.31810525, recall = 0.304
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpjm3x59ce/model.ckpt-100
accuracy : 0.6875
accuracy_baseline : 0.609375
auc : 0.73963076
auc_precision_recall : 0.64400905
average_loss : 0.59503603
label/mean : 0.390625
loss : 0.59503603
precision : 0.74509805
prediction/mean : 0.31810525
recall : 0.304
global_step : 100

for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpjm3x59ce/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-1.3046792]
logistic : [0.21337856]
probabilities : [0.78662145 0.21337858]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

Korzyści z gotowych estymatorów

Gotowe estymatory kodują sprawdzone metody, zapewniając następujące korzyści:

  • Najlepsze praktyki dotyczące określania, gdzie powinny działać różne części wykresu obliczeniowego, implementowania strategii na pojedynczym komputerze lub w klastrze.
  • Najlepsze praktyki dotyczące pisania wydarzeń (podsumowań) i ogólnie użytecznych podsumowań.

Jeśli nie używasz gotowych estymatorów, musisz samodzielnie wdrożyć powyższe funkcje.

Estymatory niestandardowe

Sercem każdego Estymatora - gotowego lub niestandardowego - jest jego funkcja modelu , czyli metoda budująca wykresy do treningu, oceny i prognozowania. Kiedy używasz gotowego Estymatora, ktoś inny zaimplementował już funkcję modelu. W przypadku korzystania z niestandardowego Estymatora musisz samodzielnie napisać funkcję modelu.

Dlatego zalecany przepływ pracy to:

  1. Zakładając, że istnieje odpowiedni, gotowy estymator, użyj go do zbudowania pierwszego modelu i użyj jego wyników do ustalenia linii bazowej.
  2. Zbuduj i przetestuj swój ogólny potok, w tym integralność i niezawodność danych, za pomocą tego gotowego kalkulatora.
  3. Jeśli dostępne są odpowiednie alternatywne, gotowe estymatory, przeprowadź eksperymenty, aby określić, który z gotowych estymatorów daje najlepsze wyniki.
  4. Ewentualnie ulepszaj swój model, tworząc własny, niestandardowy Estimator.

Utwórz estymator z modelu Keras

Możesz przekonwertować istniejące modele Keras na Estymatory za pomocą tf.keras.estimator.model_to_estimator . Dzięki temu model Keras uzyska dostęp do mocnych stron Estymatora, takich jak rozproszone szkolenie.

Utwórz wystąpienie modelu Keras MobileNet V2 i skompiluj model z optymalizatorem, stratami i metrykami, aby trenować z:

importuj tensorflow jako tf importuj tensorflow_datasets jako tfds

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 1s 0us/step

Utwórz Estimator ze skompilowanego modelu Keras. Początkowy stan modelu modelu Keras jest zachowywany w utworzonym Estimator :

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp9s6ijizi
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py:220: set_learning_phase (from tensorflow.python.keras.backend) is deprecated and will be removed after 2020-10-11.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp9s6ijizi', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Traktuj wyprowadzony Estimator tak samo, jak każdy inny Estimator .

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

Aby trenować, wywołaj funkcję pociągu Estymatora:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

Warning:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteVPYUDE/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp9s6ijizi/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp9s6ijizi/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting from: /tmp/tmp9s6ijizi/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting from: /tmp/tmp9s6ijizi/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp9s6ijizi/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp9s6ijizi/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 0.7802818, step = 0

INFO:tensorflow:loss = 0.7802818, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp9s6ijizi/model.ckpt.

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp9s6ijizi/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Loss for final step: 0.7024657.

INFO:tensorflow:Loss for final step: 0.7024657.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f6734021e10>

Podobnie, aby ocenić, wywołaj funkcję oceny estymatora:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Starting evaluation at 2020-09-23T01:22:32Z

INFO:tensorflow:Starting evaluation at 2020-09-23T01:22:32Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from /tmp/tmp9s6ijizi/model.ckpt-50

INFO:tensorflow:Restoring parameters from /tmp/tmp9s6ijizi/model.ckpt-50

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Inference Time : 2.16132s

INFO:tensorflow:Inference Time : 2.16132s

INFO:tensorflow:Finished evaluation at 2020-09-23-01:22:34

INFO:tensorflow:Finished evaluation at 2020-09-23-01:22:34

INFO:tensorflow:Saving dict for global step 50: accuracy = 0.490625, global_step = 50, loss = 0.69025326

INFO:tensorflow:Saving dict for global step 50: accuracy = 0.490625, global_step = 50, loss = 0.69025326

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmp9s6ijizi/model.ckpt-50

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmp9s6ijizi/model.ckpt-50

{'accuracy': 0.490625, 'loss': 0.69025326, 'global_step': 50}

Aby uzyskać więcej informacji, zapoznaj się z dokumentacją dotyczącą tf.keras.estimator.model_to_estimator .

SavedModels from Estimators

Estymatory eksportują SavedModels przez tf.Estimator.export_saved_model .

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.

INFO:tensorflow:Using default config.

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpvkv001gk

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpvkv001gk

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpvkv001gk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpvkv001gk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpvkv001gk/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpvkv001gk/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 0.6931472, step = 0

INFO:tensorflow:loss = 0.6931472, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpvkv001gk/model.ckpt.

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpvkv001gk/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Loss for final step: 0.5307944.

INFO:tensorflow:Loss for final step: 0.5307944.

<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f667444d780>

Aby zapisać narzędzie Estimator , musisz utworzyć serving_input_receiver . Ta funkcja buduje część tf.Graph która analizuje surowe dane odebrane przez SavedModel.

Moduł tf.estimator.export zawiera funkcje ułatwiające tworzenie tych receivers .

Poniższy kod tworzy odbiornik w oparciu o tf.Example feature_columns , który akceptuje serializowane tf.Example bufory protokołów, które są często używane z obsługą tf .

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']

INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']

INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']

INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Restoring parameters from /tmp/tmpvkv001gk/model.ckpt-50

INFO:tensorflow:Restoring parameters from /tmp/tmpvkv001gk/model.ckpt-50

INFO:tensorflow:Assets added to graph.

INFO:tensorflow:Assets added to graph.

INFO:tensorflow:No assets to write.

INFO:tensorflow:No assets to write.

INFO:tensorflow:SavedModel written to: /tmp/tmpf62r0bly/from_estimator/temp-1600824155/saved_model.pb

INFO:tensorflow:SavedModel written to: /tmp/tmpf62r0bly/from_estimator/temp-1600824155/saved_model.pb

Możesz także załadować i uruchomić ten model z Pythona:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.41081154, 0.58918846]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.36061144]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.58918846]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>}
{'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7386056 , 0.26139432]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.038734]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.26139435]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>}

tf.estimator.export.build_raw_serving_input_receiver_fn umożliwia tworzenie funkcji wejściowych, które przyjmują surowe tensory zamiast tf.train.Example . tf.train.Example s.