A model that parameterizes its forward pass by model weights.
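To make the idea concrete, here is a plain-Python sketch (no TensorFlow) of a forward pass parameterized by its weights. The hypothetical `linear_forward` keeps no state of its own, so each call uses whatever weights the caller passes in. The `(trainable, non_trainable)` tuple mirrors the initial_weights layout of this class; the linear model, the list-based "tensors", and all names here are assumptions for illustration only.

```python
from typing import List, Sequence, Tuple

# Hypothetical weight layout: (trainable, non_trainable).
# trainable = [kernel, bias]; this toy model has no non-trainable weights.
Weights = Tuple[Sequence[object], Sequence[object]]


def linear_forward(model_weights: Weights, x: List[List[float]]) -> List[float]:
    """Hypothetical forward pass: kernel . row + bias for each row of the batch.

    Weights are read only from the argument; nothing is stored between calls.
    """
    trainable, _non_trainable = model_weights
    kernel, bias = trainable
    return [sum(xi * ki for xi, ki in zip(row, kernel)) + bias for row in x]


batch = [[1.0, 2.0], [3.0, 4.0]]
weights_a = ([[0.5, 0.5], 0.0], [])
weights_b = ([[1.0, -1.0], 1.0], [])

print(linear_forward(weights_a, batch))  # [1.5, 3.5]
print(linear_forward(weights_b, batch))  # [0.0, 0.0]
```

The same function produces different outputs purely because different weights were passed in; in the real class, the forward pass must additionally be a tf.function operating on tensors.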

initial_weights: A 2-tuple (trainable, non_trainable) whose two elements are sequences of weights. Each weight must be a value convertible to a tf.Tensor (e.g. a numpy.ndarray or a Python sequence), but must not itself be a tf.Tensor.
predict_on_batch_fn: A tf.function-decorated callable that takes three arguments: model_weights, with the same structure as initial_weights; x, the first element of batch_input (or of input_spec); and training, a boolean indicating whether the call is part of a training pass (relevant for layers such as Dropout and BatchNormalization). It must return either a tensor of predictions or a structure whose first element (in the order given by tf.nest.flatten()) is a tensor of predictions.
loss_fn: A callable that takes three arguments: output, the tensor(s) returned by predict_on_batch, in a form the loss function can interpret; label, the second element of batch_input; and an optional sample_weight that weights the output.
metrics_fns: A 3-tuple of callables that, respectively, initialize the metrics state, update the metrics state, and finalize the metric values. This can be the result of tff.learning.metrics.create_functional_metric_fns or custom user-written callables.
input_spec: A 2-tuple (x, y) where each element is a nested structure of tf.TensorSpec. x describes the batched model inputs and defines the shape and dtype of the x argument to predict_on_batch_fn; y describes the batched labels for those inputs and defines the shape and dtype of the label argument to loss_fn.
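The arguments above can be mimicked with plain Python stand-ins to show how the pieces fit together. This is a sketch under stated assumptions, not TFF code: lists replace tf.Tensor values, the functions are not tf.function-decorated, there is no input_spec, and the names `predict_on_batch`, `loss`, `initialize_metrics`, `update_metrics` and `finalize_metrics` are hypothetical. The metric is a mean squared error accumulated as a (sum, count) state, and the `(state, label, output)` argument order of the update function is also an assumption; consult tff.learning.metrics.create_functional_metric_fns for the real signatures.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# initial_weights stand-in: (trainable, non_trainable) with trainable = [kernel, bias].
initial_weights = ([[0.5, 0.5], 0.0], [])


def predict_on_batch(model_weights, x: List[List[float]], training: bool = True) -> List[float]:
    # Hypothetical linear model. `training` is accepted to match the described
    # signature, but a linear model behaves identically in both modes.
    (kernel, bias), _non_trainable = model_weights
    return [sum(xi * ki for xi, ki in zip(row, kernel)) + bias for row in x]


def loss(output: Sequence[float], label: Sequence[float],
         sample_weight: Optional[Sequence[float]] = None) -> float:
    """Weighted mean squared error; every example gets weight 1.0 if sample_weight is None."""
    if sample_weight is None:
        sample_weight = [1.0] * len(output)
    total = sum(w * (o - y) ** 2 for o, y, w in zip(output, label, sample_weight))
    return total / sum(sample_weight)


# metrics_fns stand-in: (initialize, update, finalize). State = (sum of squared errors, count).
def initialize_metrics() -> Tuple[float, float]:
    return (0.0, 0.0)


def update_metrics(state: Tuple[float, float], label: Sequence[float],
                   output: Sequence[float]) -> Tuple[float, float]:
    sse, count = state
    sse += sum((o - y) ** 2 for o, y in zip(output, label))
    return (sse, count + len(label))


def finalize_metrics(state: Tuple[float, float]) -> Dict[str, float]:
    sse, count = state
    return {"mean_squared_error": sse / count if count else 0.0}


metrics_fns = (initialize_metrics, update_metrics, finalize_metrics)

# Evaluate over two (x, y) batches, threading the metrics state through each update.
batches = [([[1.0, 2.0]], [2.0]), ([[3.0, 4.0]], [3.5])]
init_fn, update_fn, finalize_fn = metrics_fns
state = init_fn()
for x, y in batches:
    output = predict_on_batch(initial_weights, x, training=False)
    print(loss(output, y))  # 0.25, then 0.0
    state = update_fn(state, y, output)
print(finalize_fn(state))  # {'mean_squared_error': 0.125}
```

Keeping the metric state as an explicit value that is passed in and returned, rather than stored in a metric object, matches the functional style of the class: the caller owns both the weights and the metrics state.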

loss: Returns the loss value computed from the model output and the label.

predict_on_batch: Returns tensor(s) that the loss function can interpret.
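Because the prediction output may be either a single tensor or a structure whose first element in tf.nest.flatten() order is the predictions, code consuming it has to pick out that element. The sketch below imitates that rule in plain Python with a hypothetical `first_flat_element` helper and a hypothetical `FakeTensor` leaf type. It handles only lists, tuples and dicts (dicts visited in sorted key order, as tf.nest.flatten does) and is not a substitute for tf.nest.flatten, which supports many more structure types.

```python
from typing import Any, List


class FakeTensor:
    """Hypothetical stand-in for tf.Tensor, treated as a leaf rather than a structure."""

    def __init__(self, values: List[float]) -> None:
        self.values = list(values)


def first_flat_element(structure: Any) -> Any:
    """Return the first leaf of a nested list/tuple/dict structure.

    Dicts are visited in sorted key order; anything that is not a list,
    tuple, or dict counts as a leaf, so a bare tensor is returned as-is.
    """
    while isinstance(structure, (list, tuple, dict)):
        if not structure:
            raise ValueError("structure has no leaves")
        if isinstance(structure, dict):
            structure = structure[sorted(structure)[0]]
        else:
            structure = structure[0]
    return structure


preds = FakeTensor([0.1, 0.9])
aux = FakeTensor([3.0])

print(first_flat_element(preds) is preds)                        # True: a bare tensor is its own first element
print(first_flat_element((preds, {"hidden": aux})) is preds)     # True: tuple order is preserved
print(first_flat_element({"logits": preds, "aux": aux}) is aux)  # True: "aux" sorts before "logits"
```

The last case shows the practical consequence of the first-element rule: if the prediction function returns a dict, the predictions are only picked up when they sit under the key that sorts first, so a tuple with the predictions in position 0 is the less surprising choice.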