Migrate early stopping


This notebook demonstrates how to set up model training with early stopping, first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with the Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training if, for example, the validation loss reaches a certain threshold.
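As a minimal, framework-free sketch of the threshold rule described above, the helpers below replay a list of per-epoch validation losses and report the epoch after which training would stop. The function names and the loss values are illustrative assumptions, not part of TensorFlow.

```python
from typing import Sequence


def reached_threshold(val_loss: float, threshold: float) -> bool:
    """Hypothetical stopping rule: stop once the validation loss falls to the threshold."""
    return val_loss <= threshold


def epochs_until_threshold(val_losses: Sequence[float], threshold: float) -> int:
    """Replay per-epoch validation losses and return how many epochs would run."""
    for epoch, val_loss in enumerate(val_losses, start=1):
        if reached_threshold(val_loss, threshold):
            return epoch  # Training stops after this epoch.
    return len(val_losses)  # The threshold was never reached.


# Illustrative (made-up) validation losses.
print(epochs_until_threshold([0.50, 0.30, 0.09, 0.05], threshold=0.1))  # 3
```

A fixed threshold assumes you know a good target value in advance; the Keras EarlyStopping callback covered below instead stops when the monitored metric stops improving.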

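In TensorFlow 2, there are three ways to implement early stopping: use the built-in Keras callback tf.keras.callbacks.EarlyStopping and pass it to Model.fit; define a custom callback and pass it to Keras Model.fit; or write a custom early stopping rule in a custom training loop (with tf.GradientTape).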

Setup

import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds

TensorFlow 1: Early stopping with an early stopping hook and tf.estimator

Start by defining functions for loading and preprocessing the MNIST dataset, as well as the model definition to be used with tf.estimator.Estimator:

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def _input_fn():
  ds_train = tfds.load(
    name='mnist',
    split='train',
    shuffle_files=True,
    as_supervised=True)

  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train

def _eval_input_fn():
  ds_test = tfds.load(
    name='mnist',
    split='test',
    shuffle_files=True,
    as_supervised=True)
  ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test

def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)

  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())

  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass make_early_stopping_hook a function that takes no arguments as its should_stop_fn parameter, and then pass the resulting hook to the hooks argument of tf.estimator.TrainSpec. The hook calls should_stop_fn periodically (as set by run_every_secs or run_every_steps), and training stops once should_stop_fn returns True.

The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:

estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
  return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])

eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpocmc6_bo
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpocmc6_bo', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/adagrad.py:77: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpocmc6_bo/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.3545606, step = 0
INFO:tensorflow:global_step/sec: 94.5711
INFO:tensorflow:loss = 1.3383636, step = 100 (1.060 sec)
INFO:tensorflow:global_step/sec: 158.428
INFO:tensorflow:loss = 0.7937969, step = 200 (0.631 sec)
INFO:tensorflow:global_step/sec: 287.334
INFO:tensorflow:loss = 0.69060934, step = 300 (0.349 sec)
INFO:tensorflow:global_step/sec: 286.658
INFO:tensorflow:loss = 0.59314424, step = 400 (0.349 sec)
INFO:tensorflow:global_step/sec: 311.591
INFO:tensorflow:loss = 0.50495726, step = 500 (0.320 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 536 vs previous value: 536. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 538.395
INFO:tensorflow:loss = 0.43083754, step = 600 (0.186 sec)
INFO:tensorflow:global_step/sec: 503.72
INFO:tensorflow:loss = 0.381118, step = 700 (0.198 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 715 vs previous value: 715. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 482.019
INFO:tensorflow:loss = 0.49349022, step = 800 (0.207 sec)
INFO:tensorflow:global_step/sec: 508.316
INFO:tensorflow:loss = 0.38730466, step = 900 (0.199 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 987 vs previous value: 987. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 452.89
INFO:tensorflow:loss = 0.44916487, step = 1000 (0.219 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 1042 vs previous value: 1042. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 519.401
INFO:tensorflow:loss = 0.44320562, step = 1100 (0.192 sec)
INFO:tensorflow:global_step/sec: 510.25
INFO:tensorflow:loss = 0.3758085, step = 1200 (0.196 sec)
INFO:tensorflow:global_step/sec: 518.649
INFO:tensorflow:loss = 0.46760654, step = 1300 (0.193 sec)
INFO:tensorflow:global_step/sec: 474.056
INFO:tensorflow:loss = 0.29544568, step = 1400 (0.211 sec)
INFO:tensorflow:global_step/sec: 461.406
INFO:tensorflow:loss = 0.28616875, step = 1500 (0.217 sec)
INFO:tensorflow:global_step/sec: 486.2
INFO:tensorflow:loss = 0.4114887, step = 1600 (0.206 sec)
WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 1678 vs previous value: 1678. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.
INFO:tensorflow:global_step/sec: 507.701
INFO:tensorflow:loss = 0.35298553, step = 1700 (0.197 sec)
INFO:tensorflow:global_step/sec: 490.541
INFO:tensorflow:loss = 0.3363277, step = 1800 (0.204 sec)
INFO:tensorflow:global_step/sec: 460.083
INFO:tensorflow:loss = 0.50634325, step = 1900 (0.217 sec)
INFO:tensorflow:global_step/sec: 436.782
INFO:tensorflow:loss = 0.2063987, step = 2000 (0.229 sec)
INFO:tensorflow:global_step/sec: 475.841
INFO:tensorflow:loss = 0.27246287, step = 2100 (0.210 sec)
INFO:tensorflow:global_step/sec: 483.322
INFO:tensorflow:loss = 0.31674564, step = 2200 (0.207 sec)
INFO:tensorflow:global_step/sec: 442.257
INFO:tensorflow:loss = 0.3334998, step = 2300 (0.226 sec)
INFO:tensorflow:global_step/sec: 476.38
INFO:tensorflow:loss = 0.2549953, step = 2400 (0.210 sec)
INFO:tensorflow:global_step/sec: 467.543
INFO:tensorflow:loss = 0.21111101, step = 2500 (0.214 sec)
INFO:tensorflow:global_step/sec: 497.051
INFO:tensorflow:loss = 0.15878338, step = 2600 (0.201 sec)
INFO:tensorflow:global_step/sec: 461.785
INFO:tensorflow:loss = 0.31587577, step = 2700 (0.219 sec)
INFO:tensorflow:global_step/sec: 493.743
INFO:tensorflow:loss = 0.47478187, step = 2800 (0.200 sec)
INFO:tensorflow:global_step/sec: 463.477
INFO:tensorflow:loss = 0.2499526, step = 2900 (0.216 sec)
INFO:tensorflow:global_step/sec: 538.27
INFO:tensorflow:loss = 0.34210858, step = 3000 (0.186 sec)
INFO:tensorflow:global_step/sec: 508.741
INFO:tensorflow:loss = 0.2128592, step = 3100 (0.197 sec)
INFO:tensorflow:global_step/sec: 519.319
INFO:tensorflow:loss = 0.40954083, step = 3200 (0.192 sec)
INFO:tensorflow:global_step/sec: 468.989
INFO:tensorflow:loss = 0.34270883, step = 3300 (0.213 sec)
INFO:tensorflow:global_step/sec: 479.856
INFO:tensorflow:loss = 0.26599607, step = 3400 (0.209 sec)
INFO:tensorflow:global_step/sec: 495.76
INFO:tensorflow:loss = 0.21713805, step = 3500 (0.201 sec)
INFO:tensorflow:global_step/sec: 440.282
INFO:tensorflow:loss = 0.22268976, step = 3600 (0.228 sec)
INFO:tensorflow:global_step/sec: 495.629
INFO:tensorflow:loss = 0.28974164, step = 3700 (0.201 sec)
INFO:tensorflow:global_step/sec: 468.695
INFO:tensorflow:loss = 0.37919793, step = 3800 (0.214 sec)
INFO:tensorflow:global_step/sec: 529.005
INFO:tensorflow:loss = 0.23738712, step = 3900 (0.189 sec)
INFO:tensorflow:global_step/sec: 494.809
INFO:tensorflow:loss = 0.29650036, step = 4000 (0.204 sec)
INFO:tensorflow:global_step/sec: 525.629
INFO:tensorflow:loss = 0.20826155, step = 4100 (0.188 sec)
INFO:tensorflow:global_step/sec: 509.573
INFO:tensorflow:loss = 0.26417816, step = 4200 (0.196 sec)
INFO:tensorflow:global_step/sec: 472.845
INFO:tensorflow:loss = 0.31241363, step = 4300 (0.212 sec)
INFO:tensorflow:global_step/sec: 510.868
INFO:tensorflow:loss = 0.32773697, step = 4400 (0.195 sec)
INFO:tensorflow:global_step/sec: 492.967
INFO:tensorflow:loss = 0.28609803, step = 4500 (0.203 sec)
INFO:tensorflow:global_step/sec: 507.394
INFO:tensorflow:loss = 0.32142323, step = 4600 (0.197 sec)
INFO:tensorflow:global_step/sec: 475.176
INFO:tensorflow:loss = 0.14882785, step = 4700 (0.211 sec)
INFO:tensorflow:global_step/sec: 503.718
INFO:tensorflow:loss = 0.312344, step = 4800 (0.198 sec)
INFO:tensorflow:global_step/sec: 497.659
INFO:tensorflow:loss = 0.37370217, step = 4900 (0.201 sec)
INFO:tensorflow:global_step/sec: 477.736
INFO:tensorflow:loss = 0.2663591, step = 5000 (0.209 sec)
INFO:tensorflow:global_step/sec: 496.559
INFO:tensorflow:loss = 0.34745598, step = 5100 (0.202 sec)
INFO:tensorflow:global_step/sec: 475.989
INFO:tensorflow:loss = 0.21809828, step = 5200 (0.210 sec)
INFO:tensorflow:global_step/sec: 474.464
INFO:tensorflow:loss = 0.2474105, step = 5300 (0.211 sec)
INFO:tensorflow:global_step/sec: 488.774
INFO:tensorflow:loss = 0.1611641, step = 5400 (0.204 sec)
INFO:tensorflow:global_step/sec: 504.942
INFO:tensorflow:loss = 0.2306528, step = 5500 (0.198 sec)
INFO:tensorflow:global_step/sec: 514.058
INFO:tensorflow:loss = 0.20716992, step = 5600 (0.195 sec)
INFO:tensorflow:global_step/sec: 458.899
INFO:tensorflow:loss = 0.16730343, step = 5700 (0.217 sec)
INFO:tensorflow:global_step/sec: 495.197
INFO:tensorflow:loss = 0.2906361, step = 5800 (0.202 sec)
INFO:tensorflow:global_step/sec: 482.244
INFO:tensorflow:loss = 0.24669808, step = 5900 (0.207 sec)
INFO:tensorflow:global_step/sec: 484.946
INFO:tensorflow:loss = 0.26403594, step = 6000 (0.207 sec)
INFO:tensorflow:global_step/sec: 486.74
INFO:tensorflow:loss = 0.19804293, step = 6100 (0.206 sec)
INFO:tensorflow:global_step/sec: 436.727
INFO:tensorflow:loss = 0.25344175, step = 6200 (0.229 sec)
INFO:tensorflow:global_step/sec: 428.73
INFO:tensorflow:loss = 0.2430937, step = 6300 (0.232 sec)
INFO:tensorflow:global_step/sec: 449.706
INFO:tensorflow:loss = 0.2842306, step = 6400 (0.222 sec)
INFO:tensorflow:global_step/sec: 440.873
INFO:tensorflow:loss = 0.2641199, step = 6500 (0.227 sec)
INFO:tensorflow:global_step/sec: 424.092
INFO:tensorflow:loss = 0.19028814, step = 6600 (0.237 sec)
INFO:tensorflow:global_step/sec: 450.352
INFO:tensorflow:loss = 0.24667627, step = 6700 (0.221 sec)
INFO:tensorflow:global_step/sec: 462.774
INFO:tensorflow:loss = 0.40046322, step = 6800 (0.216 sec)
INFO:tensorflow:global_step/sec: 460.854
INFO:tensorflow:loss = 0.14105138, step = 6900 (0.217 sec)
INFO:tensorflow:Requesting early stopping at global step 6916
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 6917...
INFO:tensorflow:Saving checkpoints for 6917 into /tmp/tmpocmc6_bo/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 6917...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-09-22T20:07:35
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpocmc6_bo/model.ckpt-6917
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 0.79520s
INFO:tensorflow:Finished evaluation at 2021-09-22-20:07:36
INFO:tensorflow:Saving dict for global step 6917: global_step = 6917, loss = 0.227278
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 6917: /tmp/tmpocmc6_bo/model.ckpt-6917
INFO:tensorflow:Loss for final step: 0.13882703.
({'loss': 0.227278, 'global_step': 6917}, [])
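The time-budget rule above depends on a module-level start_time and on the real clock, which makes it awkward to test on its own. A minimal sketch of an alternative, assuming you want to check the rule without running an Estimator: a hypothetical factory, make_time_budget_stop_fn, that returns a zero-argument should_stop_fn and accepts an injectable clock. It defaults to time.monotonic, which, unlike time.time, is not affected by system clock adjustments. FakeClock is also a hypothetical helper used only to make the demo deterministic.

```python
import time
from typing import Callable


def make_time_budget_stop_fn(
    max_seconds: float, clock: Callable[[], float] = time.monotonic
) -> Callable[[], bool]:
    """Hypothetical factory: returns a zero-argument should_stop_fn that
    becomes True once more than max_seconds have elapsed since creation."""
    start = clock()

    def should_stop_fn() -> bool:
        return clock() - start > max_seconds

    return should_stop_fn


class FakeClock:
    """Manually advanced clock, used here only to make the demo deterministic."""

    def __init__(self) -> None:
        self.now = 0.0

    def __call__(self) -> float:
        return self.now


clock = FakeClock()
stop_fn = make_time_budget_stop_fn(20, clock=clock)
clock.now = 10.0
print(stop_fn())  # False: 10 s elapsed
clock.now = 20.5
print(stop_fn())  # True: 20.5 s elapsed
```

Under these assumptions, should_stop_fn=make_time_budget_stop_fn(20) could stand in for should_stop_fn in the make_early_stopping_hook call. Note that the budget starts counting when the factory is called, not when training actually begins.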

TensorFlow 2: Early stopping with a built-in callback and Model.fit

Prepare the MNIST dataset and a simple Keras model:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)

ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In TensorFlow 2, when you use the built-in Keras Model.fit, you can configure early stopping by passing a tf.keras.callbacks.EarlyStopping callback to the callbacks parameter of Model.fit.

The EarlyStopping callback monitors a user-specified metric and ends training when that metric stops improving. (Check out the Training and evaluation with the built-in methods guide or the API docs for more information.)

Below is an example of an early stopping callback that monitors the training loss and stops training once the loss has not improved for 3 consecutive epochs (patience):

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

# Only around 25 epochs are run during training, instead of 100.
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)

len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 5s 8ms/step - loss: 0.2371 - sparse_categorical_accuracy: 0.9293 - val_loss: 0.1334 - val_sparse_categorical_accuracy: 0.9611
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.1028 - sparse_categorical_accuracy: 0.9686 - val_loss: 0.1062 - val_sparse_categorical_accuracy: 0.9667
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0703 - sparse_categorical_accuracy: 0.9783 - val_loss: 0.0993 - val_sparse_categorical_accuracy: 0.9707
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0552 - sparse_categorical_accuracy: 0.9822 - val_loss: 0.1040 - val_sparse_categorical_accuracy: 0.9680
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0420 - sparse_categorical_accuracy: 0.9865 - val_loss: 0.1033 - val_sparse_categorical_accuracy: 0.9716
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0387 - sparse_categorical_accuracy: 0.9871 - val_loss: 0.1167 - val_sparse_categorical_accuracy: 0.9691
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0321 - sparse_categorical_accuracy: 0.9893 - val_loss: 0.1396 - val_sparse_categorical_accuracy: 0.9672
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0285 - sparse_categorical_accuracy: 0.9902 - val_loss: 0.1397 - val_sparse_categorical_accuracy: 0.9671
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0263 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.1296 - val_sparse_categorical_accuracy: 0.9715
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0250 - sparse_categorical_accuracy: 0.9915 - val_loss: 0.1440 - val_sparse_categorical_accuracy: 0.9715
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0274 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.1439 - val_sparse_categorical_accuracy: 0.9710
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0241 - sparse_categorical_accuracy: 0.9923 - val_loss: 0.1429 - val_sparse_categorical_accuracy: 0.9718
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0205 - sparse_categorical_accuracy: 0.9929 - val_loss: 0.1451 - val_sparse_categorical_accuracy: 0.9753
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0196 - sparse_categorical_accuracy: 0.9936 - val_loss: 0.1562 - val_sparse_categorical_accuracy: 0.9750
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0214 - sparse_categorical_accuracy: 0.9930 - val_loss: 0.1531 - val_sparse_categorical_accuracy: 0.9748
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0178 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1712 - val_sparse_categorical_accuracy: 0.9731
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0177 - sparse_categorical_accuracy: 0.9947 - val_loss: 0.1715 - val_sparse_categorical_accuracy: 0.9755
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1826 - val_sparse_categorical_accuracy: 0.9730
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0188 - sparse_categorical_accuracy: 0.9942 - val_loss: 0.1919 - val_sparse_categorical_accuracy: 0.9732
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0190 - sparse_categorical_accuracy: 0.9944 - val_loss: 0.1703 - val_sparse_categorical_accuracy: 0.9777
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1725 - val_sparse_categorical_accuracy: 0.9764
21
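To see why the run above stopped after 21 epochs, the following sketch replays the logged training losses through a simplified patience rule in the spirit of EarlyStopping(monitor='loss', patience=3): an epoch counts as an improvement only if its loss is strictly lower than the best loss so far (min_delta=0). This is an illustrative re-implementation with a hypothetical function name, not Keras's source code, and it ignores options such as baseline and restore_best_weights.

```python
from typing import Sequence


def epochs_run_with_patience(
    losses: Sequence[float], patience: int, min_delta: float = 0.0
) -> int:
    """Simplified patience rule in 'min' mode; returns the number of epochs run."""
    best = float("inf")
    wait = 0  # Consecutive epochs without improvement.
    for epoch, loss in enumerate(losses, start=1):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch  # Stop after this epoch.
    return len(losses)


# Training losses copied from the 21-epoch log above.
train_losses = [
    0.2371, 0.1028, 0.0703, 0.0552, 0.0420, 0.0387, 0.0321,
    0.0285, 0.0263, 0.0250, 0.0274, 0.0241, 0.0205, 0.0196,
    0.0214, 0.0178, 0.0177, 0.0141, 0.0188, 0.0190, 0.0153,
]
print(epochs_run_with_patience(train_losses, patience=3))  # 21
```

The best loss (0.0141) was reached at epoch 18; epochs 19 to 21 did not improve on it, so the patience of 3 ran out at epoch 21. Passing restore_best_weights=True to EarlyStopping asks Keras to restore the model weights from the best epoch.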

TensorFlow 2: Early stopping with a custom callback and Model.fit

You can also implement a custom early stopping callback, which can likewise be passed to the callbacks parameter of Model.fit (or Model.evaluate).
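Conceptually, a callback requests a stop by setting a flag that the training loop checks. The framework-free sketch below imitates, in much simplified form, how Model.fit polls self.model.stop_training after each batch. ToyModel, StopAfterBatches, and toy_fit are hypothetical names for illustration only, not Keras APIs.

```python
class ToyModel:
    """Hypothetical stand-in for a Keras model: only tracks the stop flag."""

    def __init__(self) -> None:
        self.stop_training = False
        self.batches_seen = 0


class StopAfterBatches:
    """Hypothetical callback: requests a stop once max_batches batches have run."""

    def __init__(self, max_batches: int) -> None:
        self.max_batches = max_batches
        self.model = None  # Attached by toy_fit, similar to Keras's set_model.

    def on_train_batch_end(self, batch, logs=None):
        if self.model.batches_seen >= self.max_batches:
            self.model.stop_training = True


def toy_fit(model, callbacks, epochs, steps_per_epoch):
    """Simplified training loop; returns the number of epochs started."""
    for cb in callbacks:
        cb.model = model
    model.stop_training = False
    epochs_run = 0
    for _ in range(epochs):
        epochs_run += 1
        for step in range(steps_per_epoch):
            model.batches_seen += 1  # A real loop would run one training step here.
            for cb in callbacks:
                cb.on_train_batch_end(step)
            if model.stop_training:  # Polled after every batch.
                break
        if model.stop_training:  # Also ends the outer epoch loop.
            break
    return epochs_run


toy_model = ToyModel()
print(toy_fit(toy_model, [StopAfterBatches(25)], epochs=100, steps_per_epoch=10))  # 3
print(toy_model.batches_seen)  # 25
```

Because the flag is checked after every batch, a callback that sets it in on_train_batch_end can end training partway through an epoch.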

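In the following example, training stops once self.model.stop_training is set to True. The LimitTrainingTime callback sets this flag at the end of a batch once more than max_time_s seconds have elapsed since training began: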

class LimitTrainingTime(tf.keras.callbacks.Callback):
  """Stops training once more than `max_time_s` seconds have elapsed."""

  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None

  def on_train_begin(self, logs=None):
    # Record the wall-clock time at which training starts.
    self.start_time = time.time()

  def on_train_batch_end(self, batch, logs=None):
    # After each batch, ask Model.fit to stop if the time budget is used up.
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True

# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0131 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.1911 - val_sparse_categorical_accuracy: 0.9749
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1999 - val_sparse_categorical_accuracy: 0.9755
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0153 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9770
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9957 - val_loss: 0.2279 - val_sparse_categorical_accuracy: 0.9753
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2272 - val_sparse_categorical_accuracy: 0.9755
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0132 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2352 - val_sparse_categorical_accuracy: 0.9747
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.2421 - val_sparse_categorical_accuracy: 0.9734
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2260 - val_sparse_categorical_accuracy: 0.9785
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.2472 - val_sparse_categorical_accuracy: 0.9752
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2166 - val_sparse_categorical_accuracy: 0.9768
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0145 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2289 - val_sparse_categorical_accuracy: 0.9781
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0119 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.2310 - val_sparse_categorical_accuracy: 0.9777
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0144 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2617 - val_sparse_categorical_accuracy: 0.9781
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0119 - sparse_categorical_accuracy: 0.9972 - val_loss: 0.3007 - val_sparse_categorical_accuracy: 0.9754
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0150 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.3014 - val_sparse_categorical_accuracy: 0.9767
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0143 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2815 - val_sparse_categorical_accuracy: 0.9750
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2606 - val_sparse_categorical_accuracy: 0.9765
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0103 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2602 - val_sparse_categorical_accuracy: 0.9777
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0098 - sparse_categorical_accuracy: 0.9979 - val_loss: 0.2594 - val_sparse_categorical_accuracy: 0.9780
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9965 - val_loss: 0.3008 - val_sparse_categorical_accuracy: 0.9755
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0110 - sparse_categorical_accuracy: 0.9974 - val_loss: 0.2662 - val_sparse_categorical_accuracy: 0.9765
Epoch 22/100
469/469 [==============================] - 1s 1ms/step - loss: 0.0083 - sparse_categorical_accuracy: 0.9978 - val_loss: 0.2587 - val_sparse_categorical_accuracy: 0.9797
22
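To see why setting self.model.stop_training = True is enough, here is a framework-free sketch of the control flow. ToyModel, ToyLimitTrainingTime and toy_fit are hypothetical stand-ins for a Keras model, the callback above and Model.fit; they are not Keras code. The toy loop calls the callback hooks and checks the flag after every batch, which is an assumption about how recent TF 2.x versions of Model.fit behave and is worth checking against your version. The callback takes an injectable clock so it can be tested without waiting. It defaults to time.monotonic, which, unlike time.time, is not affected by system clock adjustments.

```python
import time


class ToyModel:
    """Hypothetical stand-in for a Keras model: it only carries the stop flag."""

    def __init__(self):
        self.stop_training = False


class ToyLimitTrainingTime:
    """Sets model.stop_training once more than max_time_s seconds have elapsed."""

    def __init__(self, max_time_s, clock=time.monotonic):
        self.max_time_s = max_time_s
        self.clock = clock
        self.start_time = None
        self.model = None  # attached by the training loop

    def on_train_begin(self, logs=None):
        self.start_time = self.clock()

    def on_train_batch_end(self, batch, logs=None):
        if self.clock() - self.start_time > self.max_time_s:
            self.model.stop_training = True


def toy_fit(model, callbacks, epochs, steps_per_epoch):
    """Hypothetical training loop; returns the number of epochs started."""
    for cb in callbacks:
        cb.model = model
        cb.on_train_begin()
    epochs_run = 0
    for epoch in range(epochs):
        epochs_run += 1
        for batch in range(steps_per_epoch):
            # A real loop would run one training step here.
            for cb in callbacks:
                cb.on_train_batch_end(batch)
            if model.stop_training:
                # The flag is honored right after the current batch.
                return epochs_run
    return epochs_run


# Demo with a fake clock that advances by one "second" per call.
ticks = iter(range(1000))
model = ToyModel()
callback = ToyLimitTrainingTime(5, clock=lambda: next(ticks))
print(toy_fit(model, [callback], epochs=100, steps_per_epoch=2))  # 3
```

With a budget of 5 fake seconds and two batches per epoch, the sixth batch ends at t = 6 > 5, so training stops partway through the third epoch instead of running all 100.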

TensorFlow 2: Early stopping with a custom training loop

In TensorFlow 2, if you are not training and evaluating with the built-in Keras methods (Model.fit and Model.evaluate), you can implement early stopping in your own custom training loop.

Start by using the Keras APIs to define another simple model, an optimizer, a loss function, and metrics:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
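# The model outputs logits, so the loss metrics also need from_logits=True.
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)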
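The last Dense(10) layer has no activation, so the model outputs raw logits. That is why loss_fn is built with from_logits=True. As a pure-Python illustration (not TensorFlow code, and the helper name is hypothetical), sparse categorical cross-entropy computed from logits is the negative log-softmax of the true class:

```python
import math


def sparse_ce_from_logits(logits, label):
    """Return -log(softmax(logits)[label]), computed stably with log-sum-exp."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


# Two equal logits: softmax gives 0.5 to each class, so the loss is log(2).
print(round(sparse_ce_from_logits([0.0, 0.0], 0), 4))  # 0.6931
```

Keras loss metrics such as tf.keras.metrics.SparseCategoricalCrossentropy default to from_logits=False. A loss metric applied to these logits therefore needs from_logits=True to report the same quantity as loss_fn.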

Define the training and evaluation step functions with tf.GradientTape, and decorate them with @tf.function to compile them into graphs for a speedup:

@tf.function
def train_step(x, y):
  # Forward pass under the tape so that gradients can be computed.
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  # Backward pass and parameter update.
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  # Accumulate the training metrics for the current epoch.
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value

@tf.function
def test_step(x, y):
  # Forward pass only (no parameter update), then accumulate validation metrics.
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)

Next, write a custom training loop in which you implement your early stopping rule manually.

The example below stops training when the validation loss has not improved for patience (here, 5) consecutive epochs:

epochs = 100
patience = 5
wait = 0
best = float('inf')  # Lowest validation loss seen so far.

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Training: one pass over the training set.
    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
        loss_value = train_step(x_batch_train, y_batch_train)
        if step % 200 == 0:
            print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
            print("Seen so far: %s samples" % ((step + 1) * 128))
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    # Validation: one pass over the test set.
    for x_batch_val, y_batch_val in ds_test:
        test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease for `patience` consecutive epochs.
    wait += 1
    if val_loss < best:
        best = val_loss
        wait = 0
    if wait >= patience:
        break
Start of epoch 0
Training loss at step 0: 2.3073
Seen so far: 128 samples
Training loss at step 200: 0.2164
Seen so far: 25728 samples
Training loss at step 400: 0.2186
Seen so far: 51328 samples
Training acc over epoch: 0.9321
Validation acc: 0.9644
Time taken: 1.73s

Start of epoch 1
Training loss at step 0: 0.0733
Seen so far: 128 samples
Training loss at step 200: 0.1581
Seen so far: 25728 samples
Training loss at step 400: 0.1625
Seen so far: 51328 samples
Training acc over epoch: 0.9704
Validation acc: 0.9681
Time taken: 1.23s

Start of epoch 2
Training loss at step 0: 0.0501
Seen so far: 128 samples
Training loss at step 200: 0.1389
Seen so far: 25728 samples
Training loss at step 400: 0.1495
Seen so far: 51328 samples
Training acc over epoch: 0.9779
Validation acc: 0.9703
Time taken: 1.17s

Start of epoch 3
Training loss at step 0: 0.0513
Seen so far: 128 samples
Training loss at step 200: 0.0638
Seen so far: 25728 samples
Training loss at step 400: 0.0930
Seen so far: 51328 samples
Training acc over epoch: 0.9830
Validation acc: 0.9719
Time taken: 1.20s

Start of epoch 4
Training loss at step 0: 0.0251
Seen so far: 128 samples
Training loss at step 200: 0.0482
Seen so far: 25728 samples
Training loss at step 400: 0.0872
Seen so far: 51328 samples
Training acc over epoch: 0.9849
Validation acc: 0.9672
Time taken: 1.18s

Start of epoch 5
Training loss at step 0: 0.0417
Seen so far: 128 samples
Training loss at step 200: 0.0302
Seen so far: 25728 samples
Training loss at step 400: 0.0362
Seen so far: 51328 samples
Training acc over epoch: 0.9878
Validation acc: 0.9703
Time taken: 1.21s
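The stopping rule described above (stop when the validation loss has not improved for patience consecutive epochs) can be isolated into a small, framework-free function, which makes it easy to test on its own. This is a sketch, not the Keras EarlyStopping implementation. The function name is hypothetical, and min_delta is an added parameter here, loosely mirroring the argument of the same name on tf.keras.callbacks.EarlyStopping: it is the minimum decrease that counts as an improvement. The function returns the epoch after which training would stop and the epoch with the best validation loss, whose weights you might want to keep.

```python
def early_stopping_epoch(val_losses, patience, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation losses.

    Epochs are 0-based. stop_epoch is None if training never stops within
    the sequence.
    """
    best = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement: remember it and reset the patience counter.
            best = loss
            best_epoch = epoch
            wait = 0
        else:
            # No improvement: count one more epoch of waiting.
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


losses = [0.5, 0.4, 0.41, 0.42, 0.39, 0.40, 0.41, 0.43]
print(early_stopping_epoch(losses, patience=3))  # (7, 4)
```

Here the improvement at epoch 4 resets the counter, so training only stops after three further epochs without a lower loss.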

Next steps