Estimator から Keras API に移行する

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

このガイドでは、TensorFlow 1 の tf.estimator.Estimator API から TensorFlow 2 の tf.keras API に移行する方法を示します。最初に、tf.estimator.Estimator を使用してトレーニングと評価のための基本モデルをセットアップして実行します。次に、tf.keras API を使用して TensorFlow 2 で同等の手順を実行します。また、tf.GradientTape をサブクラス化し、tf.keras.Model を使用してトレーニングの手順をカスタマイズする方法も学びます。

  • TensorFlow 1 では、高レベルの tf.estimator.Estimator API を使用して、モデルのトレーニングと評価、推論の実行、およびモデルの保存(提供用)を行うことができます。
  • TensorFlow 2 では、Keras API を使用して、モデルの構築、勾配の適用、 トレーニング、評価、予測などの前述のタスクを実行します。

(モデル/チェックポイント保存ワークフローを TensorFlow 2 に移行するには、SavedModel および Checkpoint 移行ガイドを確認してください。)

セットアップ

インポートと単純なデータセットから始めます。

import tensorflow as tf
import tensorflow.compat.v1 as tf1
2022-12-14 22:24:44.683460: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:24:44.683565: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-12-14 22:24:44.683576: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]

TensorFlow 1: tf.estimator.Estimator でトレーニングと評価を行う

この例では、TensorFlow 1 で tf.estimator.Estimator を使用してトレーニングと評価を実行する方法を示します。

いくつかの関数を定義することから始めます。トレーニングデータの入力関数、評価データの評価入力関数、および特徴量とラベルを使用してトレーニング演算がどのように定義されるかを Estimator に伝えるモデル関数です。

def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

Estimator をインスタンス化し、モデルをトレーニングします。

estimator = tf1.estimator.Estimator(model_fn=_model_fn)
estimator.train(_input_fn)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpzy9ux6di
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpzy9ux6di', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpzy9ux6di/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.3302933, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3...
INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmpzy9ux6di/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3...
INFO:tensorflow:Loss for final step: 0.9824529.
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7fdb44fb6e50>

評価セットを使用してプログラムを評価します。

estimator.evaluate(_eval_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-12-14T22:24:50
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpzy9ux6di/model.ckpt-3
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Inference Time : 0.25545s
INFO:tensorflow:Finished evaluation at 2022-12-14-22:24:50
INFO:tensorflow:Saving dict for global step 3: global_step = 3, loss = 2.2362373
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 3: /tmpfs/tmp/tmpzy9ux6di/model.ckpt-3
{'loss': 2.2362373, 'global_step': 3}

TensorFlow 2: 組み込みの Keras メソッドを使用してトレーニングと評価を行う

この例では、TensorFlow 2 で Keras Model.fitModel.evaluate を使用してトレーニングと評価を実行する方法を示します(詳細については、組み込みメソッドを使用したトレーニングと評価ガイドを参照してください)。

  • tf.data.Dataset API を使用してデータセットパイプラインを準備することから始めます。
  • 1 つの線形(tf.keras.layers.Dense)レイヤーを持つ単純な Keras Sequential モデルを定義します。
  • Adagrad オプティマイザをインスタンス化します(tf.keras.optimizers.Adagrad)。
  • optimizer 変数と平均二乗誤差("mse")損失を Model.compile に渡して、トレーニング用のモデルを構成します。
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer=optimizer, loss="mse")

これで、Model.fit を呼び出してモデルをトレーニングする準備が整いました。

model.fit(dataset)
3/3 [==============================] - 0s 5ms/step - loss: 10.1551
<keras.callbacks.History at 0x7fdb48799a60>

最後に、Model.evaluate を使用してモデルを評価します。

model.evaluate(eval_dataset, return_dict=True)
3/3 [==============================] - 0s 3ms/step - loss: 41.3360
{'loss': 41.33598709106445}

TensorFlow 2: カスタムトレーニングステップと組み込みの Keras メソッドを使用してトレーニングと評価を行う

TensorFlow 2 では、tf.keras.callbacks.Callbacktf.distribute.Strategy などの組み込みのトレーニングサポートを引き続き利用しながら、tf.GradientTape を使用して独自のカスタムトレーニングステップ関数を作成して、フォワードパスとバックワードパスを実行することもできます。(詳細については、Model.fit の処理をカスタマイズするおよびトレーニングループの新規作成を参照してください。)

この例では、tf.keras.Sequential をオーバーライドする Model.train_step をサブクラス化することにより、カスタム tf.keras.Model を作成することから始めます。(tf.keras.Model のサブクラス化について詳しくご覧ください)。そのクラス内で、データのバッチごとに 1 つのトレーニングステップでフォワードパスとバックワードパスを実行するカスタムの train_step 関数を定義します。

class CustomModel(tf.keras.Sequential):
  """A custom sequential model that overrides `Model.train_step`."""

  def train_step(self, data):
    batch_data, labels = data

    with tf.GradientTape() as tape:
      predictions = self(batch_data, training=True)
      # Compute the loss value (the loss function is configured
      # in `Model.compile`).
      loss = self.compiled_loss(labels, predictions)

    # Compute the gradients of the parameters with respect to the loss.
    gradients = tape.gradient(loss, self.trainable_variables)
    # Perform gradient descent by updating the weights/parameters.
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
    # Update the metrics (includes the metric that tracks the loss).
    self.compiled_metrics.update_state(labels, predictions)
    # Return a dict mapping metric names to the current values.
    return {m.name: m.result() for m in self.metrics}

次に、前と同じように以下を実行します。

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

model = CustomModel([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)

model.compile(optimizer=optimizer, loss="mse")

Model.fit を呼び出してモデルをトレーニングします。

model.fit(dataset)
3/3 [==============================] - 0s 3ms/step - loss: 0.4927
<keras.callbacks.History at 0x7fdb44170880>

最後に、Model.evaluate を使用してプログラムを評価します。

model.evaluate(eval_dataset, return_dict=True)
3/3 [==============================] - 0s 3ms/step - loss: 1.4714
{'loss': 1.4713608026504517}

Next steps

役に立つと思われる追加の Keras リソースは次のとおりです。

次のガイドは、tf.estimator API から分散ストラテジーのワークフローを移行するのに役立ちます。