Trains models and prints debug info.
tfl.test_utils.run_training_loop(
    config,
    training_data,
    keras_model,
    plot_path=None,
    input_dtype=np.float32,
    label_dtype=np.float32
)
Args:
  config: Dictionary of test case parameters. See the tests for TensorFlow
    Lattice layers.
  training_data: Triple (training_inputs, labels, raw_training_inputs), where
    training_inputs and labels are the data used to train the model passed in
    keras_model, and raw_training_inputs is a representation of
    training_inputs used for visualization.
  keras_model: Keras model to train on training_data.
  plot_path: If specified, a string giving the file name under which to save
    a PNG visualization of model output vs. ground truth. Supported only for
    1-d and 2-d inputs. For 2-d inputs, the visualization works only if
    raw_training_inputs is a mesh grid.
  input_dtype: dtype to which training inputs are converted.
  label_dtype: dtype to which labels are converted.
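For illustration, the sketch below builds a training_data triple for 2-d inputs laid out on a mesh grid and applies the same kind of dtype conversion that input_dtype and label_dtype describe. The helper make_mesh_grid_training_data and the target function x0 + x1 are hypothetical, and storing raw_training_inputs as the pair of np.meshgrid arrays is an assumption; check the TensorFlow Lattice tests for the exact layout that plot_path expects.

```python
import numpy as np


def make_mesh_grid_training_data(num_points=5, input_dtype=np.float32,
                                 label_dtype=np.float32):
    """Hypothetical helper returning (training_inputs, labels, raw_training_inputs).

    The target function x0 + x1 is an arbitrary choice for illustration and
    is not taken from TensorFlow Lattice.
    """
    # raw_training_inputs: a 2-d mesh grid (assumed layout: pair of meshgrid arrays).
    xs = np.linspace(0.0, 1.0, num_points)
    grid_x0, grid_x1 = np.meshgrid(xs, xs)
    raw_training_inputs = (grid_x0, grid_x1)

    # training_inputs: one row per grid point, one column per feature.
    training_inputs = np.stack([grid_x0.ravel(), grid_x1.ravel()], axis=1)
    # labels: one target value per example, shaped as a column.
    labels = (training_inputs[:, 0] + training_inputs[:, 1]).reshape(-1, 1)

    # Convert to the requested dtypes, analogous to input_dtype / label_dtype.
    training_inputs = np.asarray(training_inputs).astype(input_dtype)
    labels = np.asarray(labels).astype(label_dtype)
    return training_inputs, labels, raw_training_inputs


inputs, labels, raw = make_mesh_grid_training_data()
print(inputs.shape, labels.shape, inputs.dtype)  # (25, 2) (25, 1) float32
```

Keeping the flattened inputs and the grid side by side lets training use the flat rows while a 2-d plot can reuse the grid shape directly.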
Returns:
  Loss measured on the training data, and the tf.Session() if one was
  initialized explicitly during training.
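The contract described above (train the model, print debug info, return the loss measured on the training data) can be illustrated with a plain NumPy stand-in. This is not TensorFlow Lattice's implementation: run_toy_training_loop, the linear model, full-batch gradient descent, the learning rate, and the logging interval are all assumptions chosen to keep the sketch self-contained, and it returns only the loss, with no session.

```python
import numpy as np


def run_toy_training_loop(training_data, num_steps=200, learning_rate=0.5,
                          log_every=50, input_dtype=np.float32,
                          label_dtype=np.float32):
    """Hypothetical stand-in: fits a linear model, prints progress, returns MSE.

    Not the TensorFlow Lattice implementation.
    """
    # Unpack the triple; raw inputs are only needed for plotting, so ignore them.
    training_inputs, labels, _raw_training_inputs = training_data
    x = np.asarray(training_inputs).astype(input_dtype)
    y = np.asarray(labels).astype(label_dtype).reshape(-1, 1)

    weights = np.zeros((x.shape[1], 1), dtype=x.dtype)
    bias = np.zeros((1,), dtype=x.dtype)

    print(f"{'it':<10}{'Loss':<10}")
    for step in range(num_steps):
        residual = x @ weights + bias - y
        if step % log_every == 0:
            # Debug output: step number and current training loss.
            print(f"{step:<10}{float(np.mean(residual ** 2)):<10.6f}")
        # Full-batch gradient descent step on the mean squared error.
        weights -= learning_rate * 2.0 * (x.T @ residual) / len(x)
        bias -= learning_rate * 2.0 * residual.mean(axis=0)

    # Loss measured on the training data after the final update.
    final_residual = x @ weights + bias - y
    return float(np.mean(final_residual ** 2))


# Usage: a 2-d mesh grid with labels x0 + x1.
xs = np.linspace(0.0, 1.0, 5)
grid_x0, grid_x1 = np.meshgrid(xs, xs)
inputs = np.stack([grid_x0.ravel(), grid_x1.ravel()], axis=1)
final_loss = run_toy_training_loop((inputs, inputs.sum(axis=1), (grid_x0, grid_x1)))
print(f"final training loss: {final_loss:.2e}")
```

Training on the full batch at every step mirrors the idea of reporting loss "measured on training data": the value returned is the error over exactly the examples the model was fit to, not a held-out estimate.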