
tfl.test_utils.run_training_loop


Trains models and prints debug info.

tfl.test_utils.run_training_loop(
    config, training_data, keras_model, plot_path=None, input_dtype=np.float32,
    label_dtype=np.float32
)

Args:

  • config: dictionary of test-case parameters. See the tests for TensorFlow Lattice layers for examples.
  • training_data: triple (training_inputs, labels, raw_training_inputs), where training_inputs and labels are the data used to train the model passed in keras_model, and raw_training_inputs is a representation of training_inputs used for visualization.
  • keras_model: Keras model to train on training_data.
  • plot_path: if specified, a string giving the file name under which to save a PNG visualization of model output vs. ground truth. Supported only for 1-d and 2-d inputs. For 2-d inputs, the visualization works only if raw_training_inputs is a mesh grid.
  • input_dtype: dtype for input conversion.
  • label_dtype: dtype for label conversion.
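The training_data triple and the mesh-grid requirement for 2-d plotting can be illustrated with a minimal NumPy sketch. It assumes a hypothetical single-input Keras model that takes an array of shape (N, 2) and an illustrative target function x + y; the helpers make_2d_training_data and is_mesh_grid are hypothetical and are not part of tfl.test_utils.

```python
import numpy as np


def make_2d_training_data(num_points_per_dim=10, dtype=np.float32):
    """Builds an illustrative (training_inputs, labels, raw_training_inputs) triple.

    Hypothetical helper. Assumptions (not from the TensorFlow Lattice API):
    a single-input Keras model that takes an array of shape (N, 2), and the
    target function x + y.
    """
    axis = np.linspace(0.0, 1.0, num_points_per_dim)
    grid_x, grid_y = np.meshgrid(axis, axis)  # each has shape (n, n)

    # Raw inputs: every (x, y) point of the mesh grid, kept for plotting.
    raw_training_inputs = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

    # Model inputs and labels, converted to the requested dtype.
    training_inputs = raw_training_inputs.astype(dtype)
    labels = (grid_x.ravel() + grid_y.ravel()).reshape(-1, 1).astype(dtype)
    return training_inputs, labels, raw_training_inputs


def is_mesh_grid(points):
    """Returns True if 2-d points contain every (x, y) combination exactly once."""
    points = np.asarray(points)
    if points.ndim != 2 or points.shape[1] != 2:
        return False
    unique_x = np.unique(points[:, 0])
    unique_y = np.unique(points[:, 1])
    unique_points = np.unique(points, axis=0)
    # No duplicates, and the count equals |X| * |Y|, so points == X x Y.
    return (len(unique_points) == len(points)
            and len(points) == len(unique_x) * len(unique_y))


training_data = make_2d_training_data()
print(training_data[0].shape, training_data[1].shape)  # (100, 2) (100, 1)
print(is_mesh_grid(training_data[2]))  # True
```

Dropping a single point from raw_training_inputs makes is_mesh_grid return False, which is the situation in which 2-d visualization is not expected to work.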

Returns:

Loss measured on training data and tf.session() if one was initialized explicitly during training.
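Because a session is returned only when one was initialized explicitly, callers may want to handle both cases. The following is a minimal sketch that assumes the result is either a bare loss value or a (loss, session) tuple; that exact structure is an assumption not confirmed by this page, and split_training_result is a hypothetical helper.

```python
def split_training_result(result):
    """Separates the loss from an optional session in a training result.

    Hypothetical helper. Assumption: the result is either a bare loss value
    or a (loss, session) tuple.
    """
    if isinstance(result, tuple):
        loss, session = result
        return loss, session
    # No session was created explicitly during training.
    return result, None


loss, session = split_training_result(0.25)
print(loss, session)  # 0.25 None
```

If a session object is returned, it can be closed once it is no longer needed (for a tf.compat.v1.Session, via its close() method).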