Trang này được dịch bởi Cloud Translation API.
Switch to English

Công cụ ước tính

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ tay

Tài liệu này giới thiệu tf.estimator — một API TensorFlow cấp cao. Công cụ ước tính gói gọn các hành động sau:

  • đào tạo
  • đánh giá
  • sự dự đoán
  • xuất khẩu để phục vụ

Bạn có thể sử dụng Công cụ ước tính được tạo sẵn mà chúng tôi cung cấp hoặc viết Công cụ ước tính tùy chỉnh của riêng bạn. Tất cả các Công cụ ước tính - dù được tạo sẵn hay tùy chỉnh - đều là các lớp dựa trên lớp tf.estimator.Estimator .

Để có ví dụ nhanh, hãy thử hướng dẫn về Công cụ ước tính . Để biết tổng quan về thiết kế API, hãy xem sách trắng .

Thiết lập

 pip install -q -U tensorflow_datasets
WARNING: You are using pip version 20.2.2; however, version 20.2.3 is available.
You should consider upgrading via the '/tmpfs/src/tf_docs_env/bin/python -m pip install --upgrade pip' command.

import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

Ưu điểm

Tương tự như tf.keras.Model , công estimator là một trừu tượng hóa cấp mô hình. tf.estimator cung cấp một số khả năng hiện vẫn đang được phát triển cho tf.keras . Đó là:

  • Đào tạo dựa trên máy chủ tham số
  • Tích hợp đầy đủ TFX .

Công cụ ước tính Khả năng

Công cụ ước tính cung cấp những lợi ích sau:

  • Bạn có thể chạy các mô hình dựa trên Công cụ ước tính trên máy chủ cục bộ hoặc trên môi trường nhiều máy chủ phân tán mà không cần thay đổi mô hình của bạn. Hơn nữa, bạn có thể chạy các mô hình dựa trên Công cụ ước tính trên CPU, GPU hoặc TPU mà không cần mã hóa mô hình của bạn.
  • Công cụ ước tính cung cấp một vòng đào tạo phân tán an toàn kiểm soát cách thức và thời điểm:
    • tải dữ liệu
    • xử lý các trường hợp ngoại lệ
    • tạo các tệp điểm kiểm tra và khôi phục sau lỗi
    • lưu tóm tắt cho TensorBoard

Khi viết ứng dụng với Công cụ ước tính, bạn phải tách đường ống nhập dữ liệu ra khỏi mô hình. Sự tách biệt này đơn giản hóa các thử nghiệm với các tập dữ liệu khác nhau.

Sử dụng Công cụ ước tính được tạo sẵn

Công cụ ước tính được tạo sẵn cho phép bạn làm việc ở cấp độ khái niệm cao hơn nhiều so với các API TensorFlow cơ sở. Bạn không còn phải lo lắng về việc tạo biểu đồ hoặc phiên tính toán vì Công cụ ước tính xử lý tất cả "hệ thống ống nước" cho bạn. Hơn nữa, Công cụ ước tính được tạo trước cho phép bạn thử nghiệm với các kiến ​​trúc mô hình khác nhau bằng cách chỉ thực hiện các thay đổi mã tối thiểu. Ví dụ: tf.estimator.DNNClassifier là một lớp Ước tính được tạo sẵn để đào tạo các mô hình phân loại dựa trên các mạng nơ-ron truyền về phía trước, dày đặc.

Chương trình TensorFlow dựa trên Công cụ ước tính được tạo sẵn thường bao gồm bốn bước sau:

1. Viết một hàm đầu vào

Ví dụ: bạn có thể tạo một hàm để nhập tập huấn luyện và một hàm khác để nhập tập kiểm tra. Các nhà ước tính mong đợi đầu vào của họ được định dạng dưới dạng một cặp đối tượng:

  • Một từ điển trong đó các khóa là tên tính năng và giá trị là Tensors (hoặc SparseTensors) chứa dữ liệu tính năng tương ứng
  • Một Tensor chứa một hoặc nhiều nhãn

input_fn phải trả về một tf.data.Dataset tạo ra các cặp ở định dạng đó.

Ví dụ: đoạn mã sau tạo một tf.data.Dataset từ tệp train.csv của tập dữ liệu Titanic:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.experimental.AUTOTUNE))
  return titanic_batches

input_fn được thực thi trong một tf.Graph và cũng có thể trả về trực tiếp một cặp (features_dics, labels) chứa các tensors đồ thị, nhưng điều này dễ xảy ra lỗi ngoài các trường hợp đơn giản như trả về hằng số.

2. Xác định các cột tính năng.

Mỗi tf.feature_column xác định tên đối tượng, loại của nó và bất kỳ quá trình xử lý trước đầu vào nào.

Ví dụ: đoạn mã sau tạo ba cột tính năng.

  • Đầu tiên sử dụng trực tiếp tính năng age làm đầu vào dấu phẩy động.
  • Thứ hai sử dụng tính năng class làm đầu vào phân loại.
  • Loại thứ ba sử dụng embark_town như một đầu vào phân loại, nhưng sử dụng thủ hashing trick để tránh phải liệt kê các tùy chọn và đặt số lượng tùy chọn.

Để biết thêm thông tin, hãy xem hướng dẫn về cột tính năng .

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. Khởi tạo Công cụ ước tính được tạo sẵn có liên quan.

Ví dụ: đây là một bản mô tả mẫu của Công cụ ước tính được tạo sẵn có tên là LinearClassifier :

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpjm3x59ce', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Để biết thêm thông tin, hãy xem hướng dẫn phân loại tuyến tính .

4. Gọi một phương pháp đào tạo, đánh giá hoặc suy luận.

Tất cả các Công cụ ước tính đều cung cấp các phương pháp train , evaluatepredict .

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/canned/linear.py:1481: Layer.add_variable (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:112: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpjm3x59ce/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpjm3x59ce/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.5892383.

retef4edd7
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-09-23T01:21:41Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpjm3x59ce/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.74020s
INFO:tensorflow:Finished evaluation at 2020-09-23-01:21:42
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.6875, accuracy_baseline = 0.609375, auc = 0.73963076, auc_precision_recall = 0.64400905, average_loss = 0.59503603, global_step = 100, label/mean = 0.390625, loss = 0.59503603, precision = 0.74509805, prediction/mean = 0.31810525, recall = 0.304
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpjm3x59ce/model.ckpt-100
accuracy : 0.6875
accuracy_baseline : 0.609375
auc : 0.73963076
auc_precision_recall : 0.64400905
average_loss : 0.59503603
label/mean : 0.390625
loss : 0.59503603
precision : 0.74509805
prediction/mean : 0.31810525
recall : 0.304
global_step : 100

for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpjm3x59ce/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-1.3046792]
logistic : [0.21337856]
probabilities : [0.78662145 0.21337858]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']

Lợi ích của Công cụ ước tính tạo sẵn

Công cụ ước tính được tạo sẵn mã hóa các phương pháp hay nhất, mang lại những lợi ích sau:

  • Các phương pháp hay nhất để xác định vị trí các phần khác nhau của biểu đồ tính toán sẽ chạy, triển khai các chiến lược trên một máy hoặc trên một cụm.
  • Các phương pháp hay nhất để viết (tóm tắt) sự kiện và các bản tóm tắt hữu ích trên toàn cầu.

Nếu bạn không sử dụng Công cụ ước tính được tạo sẵn, bạn phải tự triển khai các tính năng trước đó.

Công cụ ước tính tùy chỉnh

Trọng tâm của mọi Công cụ ước tính — cho dù được tạo sẵn hay tùy chỉnh — là chức năng mô hình của nó, là một phương pháp xây dựng biểu đồ để đào tạo, đánh giá và dự đoán. Khi bạn đang sử dụng Công cụ ước tính được tạo sẵn, người khác đã triển khai chức năng mô hình. Khi dựa vào Công cụ ước tính tùy chỉnh, bạn phải tự viết hàm mô hình.

Vì vậy, quy trình làm việc được đề xuất là:

  1. Giả sử tồn tại một Công cụ ước tính được tạo sẵn phù hợp, hãy sử dụng nó để xây dựng mô hình đầu tiên của bạn và sử dụng kết quả của nó để thiết lập đường cơ sở.
  2. Xây dựng và kiểm tra quy trình tổng thể của bạn, bao gồm tính toàn vẹn và độ tin cậy của dữ liệu bằng Công cụ ước tính được tạo trước này.
  3. Nếu có sẵn Công cụ ước tính tạo sẵn thay thế phù hợp, hãy chạy thử nghiệm để xác định Công cụ ước tính tạo sẵn nào tạo ra kết quả tốt nhất.
  4. Có thể, cải thiện hơn nữa mô hình của bạn bằng cách xây dựng Công cụ ước tính tùy chỉnh của riêng bạn.

Tạo Công cụ ước tính từ mô hình Keras

Bạn có thể chuyển đổi các mô hình Keras hiện có thành Công cụ ước tính với tf.keras.estimator.model_to_estimator . Làm như vậy cho phép mô hình Keras của bạn tiếp cận các điểm mạnh của Công cụ ước tính, chẳng hạn như đào tạo phân tán.

Khởi tạo mô hình Keras MobileNet V2 và biên dịch mô hình với trình tối ưu hóa, tổn thất và số liệu để đào tạo với:

nhập tensorflow dưới dạng tf nhập tensorflow_datasets dưới dạng tfds

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 1s 0us/step

Tạo Estimator từ mô hình Keras đã biên dịch. Trạng thái mô hình ban đầu của mô hình Keras được giữ nguyên trong Estimator đã tạo:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp9s6ijizi
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py:220: set_learning_phase (from tensorflow.python.keras.backend) is deprecated and will be removed after 2020-10-11.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp9s6ijizi', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

Xử lý Estimator dẫn xuất như bạn làm với bất kỳ Estimator nào khác.

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

Để đào tạo, hãy gọi chức năng đào tạo của Công cụ ước tính:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
Downloading and preparing dataset cats_vs_dogs/4.0.0 (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0...

Warning:absl:1738 images were corrupted and were skipped

Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0.incompleteVPYUDE/cats_vs_dogs-train.tfrecord
Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data.
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp9s6ijizi/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp9s6ijizi/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})

INFO:tensorflow:Warm-starting from: /tmp/tmp9s6ijizi/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting from: /tmp/tmp9s6ijizi/keras/keras_model.ckpt

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Warm-started 158 variables.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp9s6ijizi/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp9s6ijizi/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 0.7802818, step = 0

INFO:tensorflow:loss = 0.7802818, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp9s6ijizi/model.ckpt.

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp9s6ijizi/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Loss for final step: 0.7024657.

INFO:tensorflow:Loss for final step: 0.7024657.

<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f6734021e10>

Tương tự, để đánh giá, hãy gọi chức năng đánh giá của Công cụ ước tính:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Starting evaluation at 2020-09-23T01:22:32Z

INFO:tensorflow:Starting evaluation at 2020-09-23T01:22:32Z

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Restoring parameters from /tmp/tmp9s6ijizi/model.ckpt-50

INFO:tensorflow:Restoring parameters from /tmp/tmp9s6ijizi/model.ckpt-50

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [1/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [2/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [3/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [4/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [5/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [6/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [7/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [8/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [9/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Evaluation [10/10]

INFO:tensorflow:Inference Time : 2.16132s

INFO:tensorflow:Inference Time : 2.16132s

INFO:tensorflow:Finished evaluation at 2020-09-23-01:22:34

INFO:tensorflow:Finished evaluation at 2020-09-23-01:22:34

INFO:tensorflow:Saving dict for global step 50: accuracy = 0.490625, global_step = 50, loss = 0.69025326

INFO:tensorflow:Saving dict for global step 50: accuracy = 0.490625, global_step = 50, loss = 0.69025326

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmp9s6ijizi/model.ckpt-50

INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmp9s6ijizi/model.ckpt-50

{'accuracy': 0.490625, 'loss': 0.69025326, 'global_step': 50}

Để biết thêm chi tiết, vui lòng tham khảo tài liệu cho tf.keras.estimator.model_to_estimator .

Đã lưu Mô hình từ Công cụ ước tính

Công cụ ước tính xuất SavedModels thông qua tf.Estimator.export_saved_model .

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.

INFO:tensorflow:Using default config.

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpvkv001gk

Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpvkv001gk

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpvkv001gk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpvkv001gk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Create CheckpointSaverHook.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Graph was finalized.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Done running local_init_op.

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpvkv001gk/model.ckpt.

INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpvkv001gk/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...

INFO:tensorflow:loss = 0.6931472, step = 0

INFO:tensorflow:loss = 0.6931472, step = 0

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpvkv001gk/model.ckpt.

INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpvkv001gk/model.ckpt.

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...

INFO:tensorflow:Loss for final step: 0.5307944.

INFO:tensorflow:Loss for final step: 0.5307944.

<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f667444d780>

Để lưu Estimator bạn cần tạo một serving_input_receiver . Hàm này xây dựng một phần của tf.Graph phân tích cú pháp dữ liệu thô mà SavedModel nhận được.

Mô-đun tf.estimator.export chứa các chức năng giúp xây dựng các receivers này.

Đoạn mã sau xây dựng một bộ thu, dựa trên feature_columns , chấp nhận các bộ đệm giao thức tf.Example được tuần tự hóa, thường được sử dụng với phân phối tf .

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Calling model_fn.

INFO:tensorflow:Done calling model_fn.

INFO:tensorflow:Done calling model_fn.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.

INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']

INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']

INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']

INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']

INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Train: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Signatures INCLUDED in export for Eval: None

INFO:tensorflow:Restoring parameters from /tmp/tmpvkv001gk/model.ckpt-50

INFO:tensorflow:Restoring parameters from /tmp/tmpvkv001gk/model.ckpt-50

INFO:tensorflow:Assets added to graph.

INFO:tensorflow:Assets added to graph.

INFO:tensorflow:No assets to write.

INFO:tensorflow:No assets to write.

INFO:tensorflow:SavedModel written to: /tmp/tmpf62r0bly/from_estimator/temp-1600824155/saved_model.pb

INFO:tensorflow:SavedModel written to: /tmp/tmpf62r0bly/from_estimator/temp-1600824155/saved_model.pb

Bạn cũng có thể tải và chạy mô hình đó, từ python:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.41081154, 0.58918846]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.36061144]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.58918846]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>}
{'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7386056 , 0.26139432]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.038734]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.26139435]], dtype=float32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>}

tf.estimator.export.build_raw_serving_input_receiver_fn cho phép bạn tạo các hàm đầu vào lấy tensors thô thay vì tf.train.Example . tf.train.Example s.