
Trains models and prints debug info.

Args:
config: dictionary of test case parameters. See the tests for TensorFlow Lattice layers for examples.
training_data: triple (training_inputs, labels, raw_training_inputs), where training_inputs and labels are data ready for training the model passed in the other parameters, and raw_training_inputs is a representation of training_inputs used for visualization.
keras_model: Keras model to train on training_data.
plot_path: if specified, a string giving the file name under which a PNG plot of model output vs. ground truth is saved. Supported only for 1-D and 2-D inputs. For 2-D inputs, the plot works only if raw_training_inputs is a mesh grid.
input_dtype: dtype to which the training inputs are converted.
label_dtype: dtype to which the labels are converted.
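The training_data triple and the mesh-grid requirement for 2-D plots can be illustrated with a small NumPy sketch. This is not part of TensorFlow Lattice: make_2d_training_data is a hypothetical helper, packing both features into a single (N, 2) array assumes a single-input Keras model, and storing the two np.meshgrid arrays as raw_training_inputs is an assumption about what the plotting code expects.

```python
import numpy as np


def make_2d_training_data(num_points=10, target_fn=None, dtype=np.float32):
    """Hypothetical helper: builds (training_inputs, labels, raw_training_inputs).

    Assumes a single-input Keras model fed with an (N, 2) array and
    assumes raw_training_inputs is the pair of mesh-grid arrays.
    """
    if target_fn is None:
        # Illustrative target only; any function of the two inputs works.
        target_fn = lambda x, y: x + y
    axis = np.linspace(0.0, 1.0, num_points)
    # Mesh grid: every combination of the two axis values.
    xx, yy = np.meshgrid(axis, axis)
    raw_training_inputs = (xx, yy)
    # Flatten the grid to one row per example, one column per feature.
    flat_x = xx.reshape(-1, 1).astype(dtype)
    flat_y = yy.reshape(-1, 1).astype(dtype)
    training_inputs = np.hstack([flat_x, flat_y])
    labels = target_fn(flat_x, flat_y).astype(dtype)
    return training_inputs, labels, raw_training_inputs


training_inputs, labels, raw = make_2d_training_data(num_points=4)
print(training_inputs.shape, labels.shape, raw[0].shape)
```

With num_points=4 the grid has 16 points, so training_inputs has one row per grid point while raw_training_inputs keeps the 4 x 4 grid layout that a surface plot needs.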

Returns: the loss measured on the training data, and the tf.Session if one was initialized explicitly during training.