Ayuda a proteger la Gran Barrera de Coral con TensorFlow en Kaggle Únete Challenge

Migrar LoggingTensorHook y StopAtStepHook a las devoluciones de llamada de Keras

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

En TensorFlow 1, se utiliza tf.estimator.LoggingTensorHook a monitorear y registrar tensores, mientras tf.estimator.StopAtStepHook ayuda a la formación de parada en una etapa determinada cuando el entrenamiento con tf.estimator.Estimator . Este cuaderno muestra cómo migrar de estas API para sus equivalentes en TensorFlow 2 usando personalizados devoluciones de llamada Keras ( tf.keras.callbacks.Callback ) con Model.fit .

Keras devoluciones de llamada son objetos que se denominan en diferentes puntos durante el entrenamiento / evaluación / predicción en la incorporada en Keras Model.fit / Model.evaluate / Model.predict API. Se puede obtener más información sobre las devoluciones de llamada en los tf.keras.callbacks.Callback documentos de la API, así como los escribiendo su propio devoluciones de llamada y de formación y evaluación con los métodos incorporados (sección Las devoluciones de llamada) El uso de guías. Para migrar de SessionRunHook en TensorFlow 1 a devoluciones de llamada Keras en TensorFlow 2, echa un vistazo a la formación migran con la lógica asistida guía.

Configuración

Comience con importaciones y un conjunto de datos simple con fines de demostración:

import tensorflow as tf
import tensorflow.compat.v1 as tf1
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]

# Define an input function.
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

TensorFlow 1: registra tensores y deja de entrenar con las API de tf.estimator

En TensorFlow 1, defines varios ganchos para controlar el comportamiento de entrenamiento. A continuación, se pasa a estos ganchos tf.estimator.EstimatorSpec .

En el siguiente ejemplo:

  • Para supervisar / tensores-registro para ejemplo, pesos modelo o pérdidas-utiliza tf.estimator.LoggingTensorHook ( tf.train.LoggingTensorHook es su alias).
  • Para detener el entrenamiento en un paso específico, se utiliza tf.estimator.StopAtStepHook ( tf.train.StopAtStepHook es su alias).
def _model_fn(features, labels, mode):
  dense = tf1.layers.Dense(1)
  logits = dense(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  # Define the stop hook.
  stop_hook = tf1.train.StopAtStepHook(num_steps=2)

  # Access tensors to be logged by names.
  kernel_name = tf.identity(dense.weights[0])
  bias_name = tf.identity(dense.weights[1])
  logging_weight_hook = tf1.train.LoggingTensorHook(
      tensors=[kernel_name, bias_name],
      every_n_iter=1)
  # Log the training loss by the tensor object.
  logging_loss_hook = tf1.train.LoggingTensorHook(
      {'loss from LoggingTensorHook': loss},
      every_n_secs=3)

  # Pass all hooks to `EstimatorSpec`.
  return tf1.estimator.EstimatorSpec(mode,
                                     loss=loss,
                                     train_op=train_op,
                                     training_hooks=[stop_hook,
                                                     logging_weight_hook,
                                                     logging_loss_hook])

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
estimator.train(_input_fn)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp3q__3yt7
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp3q__3yt7', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp3q__3yt7/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.025395721, step = 0
INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.0769143]
 [ 1.0241832]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [0.]
INFO:tensorflow:loss from LoggingTensorHook = 0.025395721
INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[-1.1124082]
 [ 0.9824805]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [-0.03549388] (0.026 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2...
INFO:tensorflow:Saving checkpoints for 2 into /tmp/tmp3q__3yt7/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2...
INFO:tensorflow:Loss for final step: 0.09248222.
<tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f05ec414d10>

TensorFlow 2: registra tensores y deja de entrenar con devoluciones de llamada personalizadas y Model.fit

En TensorFlow 2, cuando se utiliza el incorporado en Keras Model.fit (o Model.evaluate ) para la formación / evaluación, puede configurar tensor de supervisión y formación detener mediante la definición de encargo Keras tf.keras.callbacks.Callback s. A continuación, los pases al callbacks parámetro de Model.fit (o Model.evaluate ). (Más información en la Escritura de su propio devoluciones de llamada guía).

En el siguiente ejemplo:

  • Para recrear las funcionalidades de StopAtStepHook , definir una devolución de llamada (llamada StopAtStepCallback abajo) donde se reemplaza el on_batch_end método para la formación de parada después de un cierto número de pasos.
  • Para recrear el LoggingTensorHook comportamiento, definir una devolución de llamada personalizado ( LoggingTensorCallback ) donde registrar e imprimir los tensores registra manualmente, ya que el acceso a los tensores de nombres no es compatible. También puede implementar la frecuencia de registro dentro de la devolución de llamada personalizada. El siguiente ejemplo imprimirá los pesos cada dos pasos. También son posibles otras estrategias como registrar cada N segundos.
class StopAtStepCallback(tf.keras.callbacks.Callback):
  def __init__(self, stop_step=None):
    super().__init__()
    self._stop_step = stop_step

  def on_batch_end(self, batch, logs=None):
    if self.model.optimizer.iterations >= self._stop_step:
      self.model.stop_training = True
      print('\nstop training now')

class LoggingTensorCallback(tf.keras.callbacks.Callback):
  def __init__(self, every_n_iter):
      super().__init__()
      self._every_n_iter = every_n_iter
      self._log_count = every_n_iter

  def on_batch_end(self, batch, logs=None):
    if self._log_count > 0:
      self._log_count -= 1
      print("Logging Tensor Callback: dense/kernel:",
            model.layers[0].weights[0])
      print("Logging Tensor Callback: dense/bias:",
            model.layers[0].weights[1])
      print("Logging Tensor Callback loss:", logs["loss"])
    else:
      self._log_count -= self._every_n_iter

Cuando haya terminado, pasar la nueva callbacks- StopAtStepCallback y LoggingTensorCallback -al callbacks parámetro de Model.fit :

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")

# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),
                              LoggingTensorCallback(every_n_iter=2)])
1/3 [=========>....................] - ETA: 0s - loss: 3.2473Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-0.27049014],
       [-0.73790836]], dtype=float32)>
Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.04980864], dtype=float32)>
Logging Tensor Callback loss: 3.2473244667053223

stop training now
Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-0.22285421],
       [-0.6911988 ]], dtype=float32)>
Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.09196297], dtype=float32)>
Logging Tensor Callback loss: 5.644947052001953
3/3 [==============================] - 0s 4ms/step - loss: 5.6449
<keras.callbacks.History at 0x7f053022be90>

Próximos pasos

Obtenga más información sobre las devoluciones de llamada en:

También puede encontrar útiles los siguientes recursos relacionados con la migración: