Google I / O에서 기조 연설, 제품 세션, 워크샵 등을 시청 하세요. 재생 목록보기

케라스 모델로 추정기 생성하기

TensorFlow.org에서 보기 구글 코랩(Colab)에서 실행하기 깃허브(GitHub) 소스 보기 Download notebook

개요

텐서플로는 tf.keras 모델을 텐서플로 추정기(estimator)로 손쉽게 변환하는 기능을 지원하고 있습니다. 본 튜토리얼에서는 이 변환 과정의 전반적인 예시를 소개합니다.

설정

import tensorflow as tf

import numpy as np
import tensorflow_datasets as tfds

간단한 케라스 모델 만들기

케라스에서는 여러 겹의 층을 쌓아 모델을 만들 수 있습니다. 일반적으로 모델은 층의 그래프로 구성됩니다. 이 중 가장 흔한 형태는 적층형 구조를 갖고 있는 tf.keras.Sequential 모델입니다.

간단한 완전히 연결 네트워크(다층 퍼셉트론)를 만들어봅시다:

model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(3)
])

모델을 컴파일한 후, 모델 구조를 요약해 출력할 수 있습니다.

model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer='adam')
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                80        
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 3)                 51        
=================================================================
Total params: 131
Trainable params: 131
Non-trainable params: 0
_________________________________________________________________

입력 함수 만들기

데이터셋 API를 사용해 대규모 데이터셋을 다루거나 여러 장치에서 훈련할 수 있습니다.

텐서플로 추정기는 입력 파이프라인(input pipeline)이 언제 어떻게 생성되었는지 제어해야 합니다. 이를 위해서는 "입력 함수", 즉 input_fn이 필요합니다. 추정기는 이 함수를 별도의 매개변수 설정 없이 호출하게 됩니다. 이때 input_fntf.data.Dataset 객체를 반환해야 합니다.

def input_fn():
  split = tfds.Split.TRAIN
  dataset = tfds.load('iris', split=split, as_supervised=True)
  dataset = dataset.map(lambda features, labels: ({'dense_input':features}, labels))
  dataset = dataset.batch(32).repeat()
  return dataset

input_fn이 잘 구현되었는지 확인해봅니다.

for features_batch, labels_batch in input_fn().take(1):
  print(features_batch)
  print(labels_batch)
Downloading and preparing dataset iris/2.0.0 (download: 4.44 KiB, generated: Unknown size, total: 4.44 KiB) to /home/kbuilder/tensorflow_datasets/iris/2.0.0...
Shuffling and writing examples to /home/kbuilder/tensorflow_datasets/iris/2.0.0.incomplete3PQKJE/iris-train.tfrecord
Dataset iris downloaded and prepared to /home/kbuilder/tensorflow_datasets/iris/2.0.0. Subsequent calls will reuse this data.
{'dense_input': <tf.Tensor: shape=(32, 4), dtype=float32, numpy=
array([[5.1, 3.4, 1.5, 0.2],
       [7.7, 3. , 6.1, 2.3],
       [5.7, 2.8, 4.5, 1.3],
       [6.8, 3.2, 5.9, 2.3],
       [5.2, 3.4, 1.4, 0.2],
       [5.6, 2.9, 3.6, 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [5.5, 2.4, 3.7, 1. ],
       [4.6, 3.4, 1.4, 0.3],
       [7.7, 2.8, 6.7, 2. ],
       [7. , 3.2, 4.7, 1.4],
       [4.6, 3.2, 1.4, 0.2],
       [6.5, 3. , 5.2, 2. ],
       [5.5, 4.2, 1.4, 0.2],
       [5.4, 3.9, 1.3, 0.4],
       [5. , 3.5, 1.3, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [4.8, 3. , 1.4, 0.1],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [6.7, 3.3, 5.7, 2.1],
       [7.9, 3.8, 6.4, 2. ],
       [6.7, 3. , 5.2, 2.3],
       [5.8, 4. , 1.2, 0.2],
       [6.3, 2.5, 5. , 1.9],
       [5. , 3. , 1.6, 0.2],
       [6.9, 3.1, 5.1, 2.3],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.7, 4.1, 1. ],
       [5.2, 2.7, 3.9, 1.4],
       [6.7, 3. , 5. , 1.7],
       [5.7, 2.6, 3.5, 1. ]], dtype=float32)>}
tf.Tensor([0 2 1 2 0 1 1 1 0 2 1 0 2 0 0 0 0 0 2 2 2 2 2 0 2 0 2 1 1 1 1 1], shape=(32,), dtype=int64)

tf.keras.model을 추정기로 변환하기

tf.keras.modeltf.keras.estimator.model_to_estimator 함수를 이용해 tf.estimator.Estimator 객체로 변환함으로써 tf.estimator API를 통해 훈련할 수 있습니다.

import tempfile
model_dir = tempfile.mkdtemp()
keras_estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir=model_dir)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
INFO:tensorflow:Using the Keras model provided.
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py:220: set_learning_phase (from tensorflow.python.keras.backend) is deprecated and will be removed after 2020-10-11.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py:220: set_learning_phase (from tensorflow.python.keras.backend) is deprecated and will be removed after 2020-10-11.
Instructions for updating:
Simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp01gtmwr4', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp01gtmwr4', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

추정기를 훈련한 후 평가합니다.

keras_estimator.train(input_fn=input_fn, steps=500)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('평가 결과: {}'.format(eval_result))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp01gtmwr4/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp01gtmwr4/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmp01gtmwr4/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmp/tmp01gtmwr4/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 4 variables.
INFO:tensorflow:Warm-started 4 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp01gtmwr4/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp01gtmwr4/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.9010491, step = 0
INFO:tensorflow:loss = 2.9010491, step = 0
INFO:tensorflow:global_step/sec: 422.413
INFO:tensorflow:global_step/sec: 422.413
INFO:tensorflow:loss = 1.0897408, step = 100 (0.238 sec)
INFO:tensorflow:loss = 1.0897408, step = 100 (0.238 sec)
INFO:tensorflow:global_step/sec: 522.818
INFO:tensorflow:global_step/sec: 522.818
INFO:tensorflow:loss = 0.663774, step = 200 (0.191 sec)
INFO:tensorflow:loss = 0.663774, step = 200 (0.191 sec)
INFO:tensorflow:global_step/sec: 514.363
INFO:tensorflow:global_step/sec: 514.363
INFO:tensorflow:loss = 0.54841423, step = 300 (0.194 sec)
INFO:tensorflow:loss = 0.54841423, step = 300 (0.194 sec)
INFO:tensorflow:global_step/sec: 508.819
INFO:tensorflow:global_step/sec: 508.819
INFO:tensorflow:loss = 0.48413682, step = 400 (0.197 sec)
INFO:tensorflow:loss = 0.48413682, step = 400 (0.197 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...
INFO:tensorflow:Saving checkpoints for 500 into /tmp/tmp01gtmwr4/model.ckpt.
INFO:tensorflow:Saving checkpoints for 500 into /tmp/tmp01gtmwr4/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...
INFO:tensorflow:Loss for final step: 0.6252237.
INFO:tensorflow:Loss for final step: 0.6252237.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_v1.py:2048: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-09-23T04:23:15Z
INFO:tensorflow:Starting evaluation at 2020-09-23T04:23:15Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp01gtmwr4/model.ckpt-500
INFO:tensorflow:Restoring parameters from /tmp/tmp01gtmwr4/model.ckpt-500
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.17461s
INFO:tensorflow:Inference Time : 0.17461s
INFO:tensorflow:Finished evaluation at 2020-09-23-04:23:15
INFO:tensorflow:Finished evaluation at 2020-09-23-04:23:15
INFO:tensorflow:Saving dict for global step 500: global_step = 500, loss = 0.41598433
INFO:tensorflow:Saving dict for global step 500: global_step = 500, loss = 0.41598433
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmp/tmp01gtmwr4/model.ckpt-500
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmp/tmp01gtmwr4/model.ckpt-500
평가 결과: {'loss': 0.41598433, 'global_step': 500}