![]() | ![]() | ![]() | ![]() |
이 문서는 고수준 TensorFlow API 인 tf.estimator
소개합니다. 에스티 메이터는 다음 작업을 캡슐화합니다.
- 훈련
- 평가
- 예측
- 서빙을 위해 내보내기
TensorFlow는 미리 만들어진 여러 에스티 메이터를 구현합니다. 커스텀 추정기는 여전히 지원되지만 주로 하위 호환성 측정으로 사용됩니다. 사용자 정의 추정기는 새 코드에 사용해서는 안됩니다 . 모든 tf.estimator.Estimator
사전 제작 또는 사용자 지정)는 tf.estimator.Estimator
클래스를 기반으로하는 클래스입니다.
간단한 예를 보려면 Estimator 튜토리얼을 참조하십시오 . API 설계에 대한 개요는 백서를 확인하십시오.
설정
pip install -q -U tensorflow_datasets
import tempfile
import os
import tensorflow as tf
import tensorflow_datasets as tfds
장점
tf.keras.Model
과 유사하게 estimator
는 모델 수준의 추상화입니다. tf.estimator
는 현재 tf.keras
용으로 개발중인 몇 가지 기능을 제공합니다. 이것들은:
- 매개 변수 서버 기반 교육
- 완전한 TFX 통합
에스티 메이터 기능
에스티 메이터는 다음과 같은 이점을 제공합니다.
- 모델을 변경하지 않고 로컬 호스트 또는 분산 다중 서버 환경에서 Estimator 기반 모델을 실행할 수 있습니다. 또한 모델을 다시 코딩하지 않고도 CPU, GPU 또는 TPU에서 Estimator 기반 모델을 실행할 수 있습니다.
- 에스티 메이터는 다음과 같은 방법과시기를 제어하는 안전한 분산 학습 루프를 제공합니다.
- 데이터로드
- 예외 처리
- 체크 포인트 파일 생성 및 장애 복구
- TensorBoard에 대한 요약 저장
에스티 메이터로 애플리케이션을 작성할 때 모델에서 데이터 입력 파이프 라인을 분리해야합니다. 이러한 분리는 다른 데이터 세트를 사용한 실험을 단순화합니다.
미리 만들어진 에스티 메이터 사용
사전 제작 된 에스티 메이터를 사용하면 기본 TensorFlow API보다 훨씬 더 높은 개념 수준에서 작업 할 수 있습니다. 에스티 메이터가 모든 "배관"을 처리하므로 더 이상 계산 그래프 또는 세션 생성에 대해 걱정할 필요가 없습니다. 또한 사전 제작 된 에스티 메이터를 사용하면 최소한의 코드 변경 만 수행하여 다양한 모델 아키텍처를 실험 할 수 있습니다. 예를 들어 tf.estimator.DNNClassifier
는 밀집된 피드 포워드 신경망을 기반으로 분류 모델을 훈련하는 사전 제작 된 Estimator 클래스입니다.
사전 제작 된 Estimator에 의존하는 TensorFlow 프로그램은 일반적으로 다음 4 단계로 구성됩니다.
1. 입력 함수 작성
예를 들어 훈련 세트를 가져 오는 함수 하나와 테스트 세트를 가져 오는 다른 함수를 만들 수 있습니다. 에스티 메이터는 입력이 한 쌍의 객체로 형식화 될 것으로 예상합니다.
- 키가 기능 이름이고 값이 해당 기능 데이터를 포함하는 Tensor (또는 SparseTensors) 인 사전
- 하나 이상의 레이블을 포함하는 Tensor
input_fn
은 해당 형식의 쌍을 생성하는tf.data.Dataset
을 반환해야합니다.
예를 들어, 다음 코드는 빌드tf.data.Dataset
타이타닉 데이터 세트의에서 train.csv
파일을 :
def train_input_fn():
titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic = tf.data.experimental.make_csv_dataset(
titanic_file, batch_size=32,
label_name="survived")
titanic_batches = (
titanic.cache().repeat().shuffle(500)
.prefetch(tf.data.AUTOTUNE))
return titanic_batches
input_fn
A의 실행 tf.Graph
또한 직접 돌아갈 수 (features_dics, labels)
그래프 텐서를 포함하는 한 쌍의, 그러나 이것은 리턴 상수 같은 간단한 경우 에러 유발 밖이다.
2. 기능 열을 정의합니다.
각 tf.feature_column
은 기능 이름, 유형 및 모든 입력 전처리를 식별합니다.
예를 들어 다음 스 니펫은 3 개의 특성 열을 만듭니다.
- 첫 번째는
age
기능을 부동 소수점 입력으로 직접 사용합니다. - 두 번째는
class
기능을 범주 입력으로 사용합니다. - 세 번째는
embark_town
을 범주 입력으로 사용하지만hashing trick
을 사용하여 옵션을 열거하고 옵션 수를 설정할 필요가 없습니다.
age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third'])
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)
3. 관련 사전 제작 된 Estimator를 인스턴스화합니다.
예를 들어, 다음은 LinearClassifier
라는 사전 제작 된 Estimator의 샘플 인스턴스화입니다.
model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
model_dir=model_dir,
feature_columns=[embark, cls, age],
n_classes=2
)
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpu27sw9ie', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
자세한 내용은 선형 분류기 튜토리얼을 참조하십시오 .
4. 교육, 평가 또는 추론 방법을 호출합니다.
모든 에스티 메이터는 train
, evaluate
및 predict
방법을 제공합니다.
model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv 32768/30874 [===============================] - 0s 0us/step INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1727: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead. warnings.warn('`layer.add_variable` is deprecated and ' Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:134: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpu27sw9ie/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100... INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpu27sw9ie/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100... INFO:tensorflow:Loss for final step: 0.62258995.
result = model.evaluate(train_input_fn, steps=10)
for key, value in result.items():
print(key, ":", value)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-01-08T02:56:30Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpu27sw9ie/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.67613s INFO:tensorflow:Finished evaluation at 2021-01-08-02:56:31 INFO:tensorflow:Saving dict for global step 100: accuracy = 0.715625, accuracy_baseline = 0.60625, auc = 0.7403657, auc_precision_recall = 0.6804854, average_loss = 0.5836128, global_step = 100, label/mean = 0.39375, loss = 0.5836128, precision = 0.739726, prediction/mean = 0.34897345, recall = 0.42857143 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpu27sw9ie/model.ckpt-100 accuracy : 0.715625 accuracy_baseline : 0.60625 auc : 0.7403657 auc_precision_recall : 0.6804854 average_loss : 0.5836128 label/mean : 0.39375 loss : 0.5836128 precision : 0.739726 prediction/mean : 0.34897345 recall : 0.42857143 global_step : 100
for pred in model.predict(train_input_fn):
for key, value in pred.items():
print(key, ":", value)
break
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpu27sw9ie/model.ckpt-100 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. logits : [-0.73942876] logistic : [0.32312906] probabilities : [0.6768709 0.3231291] class_ids : [0] classes : [b'0'] all_class_ids : [0 1] all_classes : [b'0' b'1']
미리 만들어진 에스티 메이터의 이점
사전 제작 된 에스티 메이터는 모범 사례를 인코딩하여 다음과 같은 이점을 제공합니다.
- 계산 그래프의 다른 부분이 실행되어야하는 위치를 결정하고 단일 머신 또는 클러스터에서 전략을 구현하는 모범 사례입니다.
- 이벤트 (요약) 작성 및 보편적으로 유용한 요약에 대한 모범 사례.
미리 만들어진 에스티 메이터를 사용하지 않는 경우 이전 기능을 직접 구현해야합니다.
커스텀 에스티 메이터
사전 제작이든 맞춤형이든 모든 Estimator의 핵심은 학습, 평가 및 예측을위한 그래프를 작성하는 방법 인 모델 함수 인 model_fn
입니다. 미리 만들어진 Estimator를 사용할 때 다른 사람이 이미 모델 기능을 구현 한 것입니다. 커스텀 에스티 메이터에 의존하는 경우 모델 함수를 직접 작성해야합니다.
Keras 모델에서 에스티 메이터 생성
tf.keras.estimator.model_to_estimator
를 사용하여 기존 tf.keras.estimator.model_to_estimator
모델을 tf.keras.estimator.model_to_estimator
로 변환 할 수 있습니다. 이는 모델 코드를 현대화하려는 경우 유용하지만 훈련 파이프 라인에는 여전히 에스티 메이터가 필요합니다.
Keras MobileNet V2 모델을 인스턴스화하고 최적화 프로그램, 손실 및 메트릭을 사용하여 모델을 컴파일하여 다음과 같이 학습합니다.
keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False
estimator_model = tf.keras.Sequential([
keras_mobilenet_v2,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(1)
])
# Compile the model
estimator_model.compile(
optimizer='adam',
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5 9412608/9406464 [==============================] - 0s 0us/step
컴파일 된 Keras 모델에서 Estimator
를 만듭니다. Keras 모델의 초기 모델 상태는 생성 된 Estimator
에서 유지됩니다.
est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpeaonpwe8 INFO:tensorflow:Using the Keras model provided. /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/backend.py:434: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model. warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and ' INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpeaonpwe8', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
파생 된 Estimator
를 다른 Estimator
와 마찬가지로 취급하십시오.
IMG_SIZE = 160 # All images will be resized to 160x160
def preprocess(image, label):
image = tf.cast(image, tf.float32)
image = (image/127.5) - 1
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image, label
def train_input_fn(batch_size):
data = tfds.load('cats_vs_dogs', as_supervised=True)
train_data = data['train']
train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
return train_data
훈련하려면 Estimator의 훈련 함수를 호출하십시오.
est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
Downloading and preparing dataset 786.68 MiB (download: 786.68 MiB, generated: Unknown size, total: 786.68 MiB) to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0... Warning:absl:1738 images were corrupted and were skipped Dataset cats_vs_dogs downloaded and prepared to /home/kbuilder/tensorflow_datasets/cats_vs_dogs/4.0.0. Subsequent calls will reuse this data. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpeaonpwe8/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpeaonpwe8/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting from: /tmp/tmpeaonpwe8/keras/keras_model.ckpt INFO:tensorflow:Warm-starting from: /tmp/tmpeaonpwe8/keras/keras_model.ckpt INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Warm-started 158 variables. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6884984, step = 0 INFO:tensorflow:loss = 0.6884984, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpeaonpwe8/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.67705643. INFO:tensorflow:Loss for final step: 0.67705643. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f3d7c3822b0>
마찬가지로 평가하려면 Estimator의 평가 함수를 호출합니다.
est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2325: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically. warnings.warn('`Model.state_updates` will be removed in a future version. ' INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:32Z INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:32Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpeaonpwe8/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmpeaonpwe8/model.ckpt-50 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 2.42050s INFO:tensorflow:Inference Time : 2.42050s INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:35 INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:35 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.515625, global_step = 50, loss = 0.6688157 INFO:tensorflow:Saving dict for global step 50: accuracy = 0.515625, global_step = 50, loss = 0.6688157 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpeaonpwe8/model.ckpt-50 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpeaonpwe8/model.ckpt-50 {'accuracy': 0.515625, 'loss': 0.6688157, 'global_step': 50}
자세한 내용은 tf.keras.estimator.model_to_estimator
문서를 참조하세요.
Estimator로 객체 기반 체크 포인트 저장
에스티 메이터는 기본적으로 체크 포인트 가이드에 설명 된 개체 그래프가 아닌 변수 이름으로 체크 포인트를 저장합니다. tf.train.Checkpoint
는 이름 기반 체크 포인트를 읽지 만 Estimator의 model_fn
외부로 모델의 일부를 이동할 때 변수 이름이 변경 될 수 있습니다. 포워드 호환성을 위해 객체 기반 체크 포인트를 저장하면 에스티 메이터 내부에서 모델을 학습 한 다음 모델 외부에서 사용하기가 더 쉽습니다.
import tensorflow.compat.v1 as tf_compat
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
def model_fn(features, labels, mode):
net = Net()
opt = tf.keras.optimizers.Adam(0.1)
ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
optimizer=opt, net=net)
with tf.GradientTape() as tape:
output = net(features['x'])
loss = tf.reduce_mean(tf.abs(output - features['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
return tf.estimator.EstimatorSpec(
mode,
loss=loss,
train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
ckpt.step.assign_add(1)),
# Tell the Estimator to save "ckpt" in an object-based format.
scaffold=tf_compat.train.Scaffold(saver=ckpt))
tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 4.4040537, step = 0 INFO:tensorflow:loss = 4.4040537, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 35.247967. INFO:tensorflow:Loss for final step: 35.247967. <tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f3d64534518>
tf.train.Checkpoint
다음의에서 견적의 체크 포인트를로드 할 수 있습니다 model_dir
.
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy() # From est.train(..., steps=10)
10
에스티 메이터에서 저장된 모델
에스티 메이터는 tf.Estimator.export_saved_model을 통해 tf.Estimator.export_saved_model
모델을 내 tf.Estimator.export_saved_model
.
input_column = tf.feature_column.numeric_column("x")
estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])
def input_fn():
return tf.data.Dataset.from_tensor_slices(
({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpczwhe6jk Warning:tensorflow:Using temporary folder as model directory: /tmp/tmpczwhe6jk INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpczwhe6jk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpczwhe6jk', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:loss = 0.6931472, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50... INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpczwhe6jk/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50... INFO:tensorflow:Loss for final step: 0.48830828. INFO:tensorflow:Loss for final step: 0.48830828. <tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f3d6452eb00>
Estimator
를 저장하려면 serving_input_receiver
를 만들어야합니다. 이 함수는 tf.Graph
에서받은 원시 데이터를 구문 분석하는 tf.Graph
의 일부를 빌드합니다.
tf.estimator.export
모듈에는 이러한 receivers
구축하는 데 도움이되는 함수가 포함되어 있습니다.
다음 코드는 tf-serving 과 함께 자주 사용되는 직렬화 된 tf.Example
프로토콜 버퍼를 허용하는 feature_columns
기반으로 수신기를 빌드합니다.
tmpdir = tempfile.mkdtemp()
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
tf.feature_column.make_parse_example_spec([input_column]))
estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict'] INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Restoring parameters from /tmp/tmpczwhe6jk/model.ckpt-50 INFO:tensorflow:Restoring parameters from /tmp/tmpczwhe6jk/model.ckpt-50 INFO:tensorflow:Assets added to graph. INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: /tmp/tmp16t8uhub/from_estimator/temp-1610074656/saved_model.pb INFO:tensorflow:SavedModel written to: /tmp/tmp16t8uhub/from_estimator/temp-1610074656/saved_model.pb
Python에서 해당 모델을로드하고 실행할 수도 있습니다.
imported = tf.saved_model.load(estimator_path)
def predict(x):
example = tf.train.Example()
example.features.feature["x"].float_list.value.extend([x])
return imported.signatures["predict"](
examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.581246]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.32789052]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.418754, 0.581246]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>} {'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.24376468]], dtype=float32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1321492]], dtype=float32)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7562353 , 0.24376468]], dtype=float32)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>}
tf.estimator.export.build_raw_serving_input_receiver_fn
당신이 원시 텐서보다는 걸릴 입력 기능을 만들 수 있습니다 tf.train.Example
들.
Estimator와 함께 tf.distribute.Strategy
사용 (제한적 지원)
tf.estimator
는 원래 비동기 매개 변수 서버 접근 방식을 지원했던 분산 학습 TensorFlow API입니다. tf.estimator
지금 지원 tf.distribute.Strategy
. tf.estimator
사용하는 경우 코드를 거의 변경하지 않고 분산 학습으로 변경할 수 있습니다. 이를 통해 Estimator 사용자는 이제 여러 GPU 및 여러 작업자에서 동기식 분산 학습을 수행하고 TPU를 사용할 수 있습니다. 그러나 Estimator의 이러한 지원은 제한적입니다. 자세한 내용은 아래 의 지금 지원되는 기능 섹션을 확인하세요.
Estimator와 함께 tf.distribute.Strategy
를 사용하는 tf.distribute.Strategy
와 약간 다릅니다. strategy.scope
를 사용하는 대신 이제 Estimator의 RunConfig
에 전략 개체를 전달합니다.
자세한 내용은 배포 된 교육 가이드 를 참조하십시오.
다음은 미리 LinearRegressor
Estimator LinearRegressor
및 MirroredStrategy
LinearRegressor
를 보여주는 코드 스 니펫입니다.
mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
feature_columns=[tf.feature_column.numeric_column('feats')],
optimizer='SGD',
config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Initializing RunConfig with distribution strategies. INFO:tensorflow:Not using Distribute Coordinator. INFO:tensorflow:Not using Distribute Coordinator. Warning:tensorflow:Using temporary folder as model directory: /tmp/tmp4uihzu_a Warning:tensorflow:Using temporary folder as model directory: /tmp/tmp4uihzu_a INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4uihzu_a', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp4uihzu_a', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f3e84699518>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
여기에서는 미리 만들어진 Estimator를 사용하지만 동일한 코드가 사용자 지정 Estimator에서도 작동합니다. train_distribute
는 훈련 배포 방법을 결정하고 eval_distribute
는 평가 배포 방법을 결정합니다. 이것은 훈련과 평가 모두에 동일한 전략을 사용하는 Keras와의 또 다른 차이점입니다.
이제 입력 함수로이 에스티 메이터를 훈련하고 평가할 수 있습니다.
def input_fn():
dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. Warning:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/util.py:96: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version. Instructions for updating: Use the iterator's `initializer` property instead. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:loss = 1.0, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10... INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmp4uihzu_a/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10... INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Loss for final step: 2.877698e-13. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:41Z INFO:tensorflow:Starting evaluation at 2021-01-08T02:57:41Z INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmp4uihzu_a/model.ckpt-10 INFO:tensorflow:Restoring parameters from /tmp/tmp4uihzu_a/model.ckpt-10 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.26266s INFO:tensorflow:Inference Time : 0.26266s INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:42 INFO:tensorflow:Finished evaluation at 2021-01-08-02:57:42 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmp4uihzu_a/model.ckpt-10 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmp4uihzu_a/model.ckpt-10 {'average_loss': 1.4210855e-14, 'label/mean': 1.0, 'loss': 1.4210855e-14, 'prediction/mean': 0.99999994, 'global_step': 10}
여기에서 Estimator와 Keras의 또 다른 차이점은 입력 처리입니다. Keras에서 데이터 세트의 각 배치는 여러 복제본에 자동으로 분할됩니다. 그러나 Estimator에서는 자동 일괄 분할을 수행하지 않으며 다른 작업자간에 데이터를 자동으로 분할하지 않습니다. 작업자와 장치간에 데이터를 배포하는 방법을 완전히 제어 할 수 있으며 데이터 배포 방법을 지정하려면 input_fn
을 제공해야합니다.
input_fn
은 작업 input_fn
한 번 호출되므로 작업 자당 하나의 데이터 세트를 제공합니다. 그런 다음 해당 데이터 세트의 배치 하나가 해당 작업자의 복제본 하나에 공급되어 작업자 1 명의 복제본 N 개에 대해 N 개의 배치를 사용합니다. 즉, 데이터 집합에 의해 반환 input_fn
크기의 일괄 제공해야 PER_REPLICA_BATCH_SIZE
. 그리고 단계의 전역 배치 크기는 PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
로 얻을 수 있습니다.
다중 작업자 교육을 수행 할 때 작업자간에 데이터를 분할하거나 각각에 대해 임의의 시드로 섞어 야합니다. Estimator를 사용한 다중 작업자 교육 자습서에서이를 수행하는 방법의 예를 확인할 수 있습니다.
마찬가지로 다중 작업자 및 매개 변수 서버 전략도 사용할 수 있습니다. 코드는 동일하게 유지되지만 tf.estimator.train_and_evaluate
를 사용하고 클러스터에서 실행중인 각 바이너리에 대해 TF_CONFIG
환경 변수를 설정해야합니다.
현재 지원되는 것은 무엇입니까?
TPUStrategy
제외한 모든 전략을 사용한 Estimator 학습은 제한적으로 지원됩니다. 기본 교육 및 평가는 작동하지만 v1.train.Scaffold
와 같은 여러 고급 기능은 v1.train.Scaffold
하지 않습니다. 또한이 통합에는 여러 버그가있을 수 있으며이 지원을 적극적으로 개선 할 계획이 없습니다 (Keras 및 사용자 지정 교육 루프 지원에 중점을 둡니다). 가능하다면 해당 API와 함께 tf.distribute
를 대신 사용하는 것이 좋습니다.
교육 API | MirroredStrategy | TPU 전략 | MultiWorkerMirrored 전략 | CentralStorage 전략 | ParameterServerStrategy |
---|---|---|---|---|---|
Estimator API | 제한된 지원 | 지원되지 않음 | 제한된 지원 | 제한된 지원 | 제한된 지원 |
예제 및 튜토리얼
다음은 Estimator에서 다양한 전략을 사용하는 방법을 보여주는 몇 가지 종단 간 예제입니다.
- Estimator를 사용한 다중 작업자 교육 자습서 는 MNIST 데이터 세트에서
MultiWorkerMirroredStrategy
를 사용하여 여러 작업자와 교육 하는 방법을 보여줍니다. - Kubernetes 템플릿을 사용하여
tensorflow/ecosystem
에서 배포 전략 으로 다중 작업자 교육 을 실행하는 종단 간 예제입니다.tf.keras.estimator.model_to_estimator
모델로 시작하여tf.keras.estimator.model_to_estimator
API를 사용하여 Estimator로 변환합니다. -
MirroredStrategy
또는MultiWorkerMirroredStrategy
사용하여 훈련 할 수있는 공식 ResNet50 모델.