tfl.test_utils.run_training_loop


Trains models and prints debug info.

Args:
  config: Dictionary of test case parameters. See the tests for TensorFlow Lattice layers.
  training_data: A triple (training_inputs, labels, raw_training_inputs), where training_inputs and labels are the data used to train the model passed via keras_model, and raw_training_inputs is a representation of training_inputs used for visualization.
  keras_model: Keras model to train on training_data.
  plot_path: If specified, a string giving the file name under which to save a PNG visualization of model output versus ground truth. Supported only for 1-d and 2-d inputs. For 2-d inputs, the visualization works only if raw_training_inputs is a mesh grid.
  input_dtype: dtype for input conversion.
  label_dtype: dtype for label conversion.
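A minimal sketch of assembling the training_data triple for a 2-d problem so that raw_training_inputs lies on a mesh grid, using only NumPy. The helper name make_mesh_grid_training_data, the target function x + y, and the choice of one array per feature for training_inputs are hypothetical illustrations, not part of the TensorFlow Lattice API; the layout keras_model actually expects depends on how its inputs are defined.

```python
import numpy as np


def make_mesh_grid_training_data(num_points=5):
    """Hypothetical helper returning (training_inputs, labels, raw_training_inputs).

    The grid range [0, 1] and the target function x + y are assumptions
    chosen only for illustration.
    """
    # Raw inputs laid out on a regular 2-d mesh grid, the layout required
    # for plotting 2-d inputs via plot_path.
    xs = np.linspace(0.0, 1.0, num_points)
    ys = np.linspace(0.0, 1.0, num_points)
    grid_x, grid_y = np.meshgrid(xs, ys)
    raw_training_inputs = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

    # Hypothetical ground truth: the sum of the two coordinates, one label per row.
    labels = (raw_training_inputs[:, 0] + raw_training_inputs[:, 1]).reshape(-1, 1)

    # Model-ready inputs: one (N, 1) array per feature, as a multi-input Keras
    # model might expect (an assumption; a single (N, 2) array may suit others).
    training_inputs = [raw_training_inputs[:, [0]], raw_training_inputs[:, [1]]]

    return training_inputs, labels, raw_training_inputs


training_data = make_mesh_grid_training_data(num_points=4)
```

The resulting tuple is shaped to be passed as training_data. Because input_dtype and label_dtype govern dtype conversion, the sketch leaves the arrays in NumPy's default float64 rather than casting them itself.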

Returns:
  The loss measured on the training data, and the TensorFlow session (tf.session()) if one was initialized explicitly during training.