tff.learning.models.FunctionalModel

A model that parameterizes the forward pass by model weights.

initial_weights A 2-tuple (trainable, non_trainable) where the two elements are sequences of weights. Weights must be values convertible to tf.Tensor (e.g. numpy.ndarray, Python sequences, etc.), but not tf.Tensor values themselves.
forward_pass_fn A tf.function-decorated callable that takes three arguments: model_weights, with the same structure as initial_weights; batch_input, a nested structure of tensors matching input_spec; and training, a boolean determining whether the call is during a training pass (e.g. for Dropout, BatchNormalization, etc.).
predict_on_batch_fn A tf.function-decorated callable that takes three arguments: model_weights, with the same structure as initial_weights; x, the first element of batch_input (or input_spec); and training, a boolean determining whether the call is during a training pass (e.g. for Dropout, BatchNormalization, etc.).
loss_fn A callable that takes three arguments: output, the tensor(s) returned by predict_on_batch in a form the loss function can interpret; label, the second element of batch_input; and an optional sample_weight that weights the output.
metrics_fns A 3-tuple of callables that initialize the metrics state, update the metrics state, and finalize the metrics values, respectively. This can be the result of tff.learning.metrics.create_functional_metric_fns or custom user-written callables.
input_spec A 2-tuple of (x, y) where each element is a nested structure of tf.TensorSpec that defines the shape and dtypes of batch_input to forward_pass_fn. x corresponds to batched model inputs and y corresponds to batched labels for those inputs.
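The argument contracts above can be illustrated with a minimal sketch. It uses plain NumPy stand-ins rather than tf.function-decorated callables, and it does not call the FunctionalModel constructor. The linear-regression model, the array shapes, and the dictionary returned by forward_pass_fn are all assumptions made for illustration; this page does not state what forward_pass_fn returns.

```python
import numpy as np

# Hypothetical linear-regression model: y_hat = x @ kernel + bias.
# initial_weights is a 2-tuple (trainable, non_trainable); each element is a
# sequence of array-like values (NumPy arrays here, not tf.Tensor values).
initial_weights = (
    (np.zeros((3, 1), dtype=np.float32), np.zeros((1,), dtype=np.float32)),
    (),  # this sketch has no non-trainable weights
)


def predict_on_batch_fn(model_weights, x, training=True):
    """Returns predictions for inputs `x` under `model_weights`."""
    del training  # no Dropout/BatchNormalization here, so the flag is unused
    trainable, _ = model_weights
    kernel, bias = trainable
    return x @ kernel + bias


def loss_fn(output, label, sample_weight=None):
    """Mean squared error, optionally weighted per example."""
    per_example = np.mean((output - label) ** 2, axis=-1)
    if sample_weight is None:
        return float(np.mean(per_example))
    sample_weight = np.asarray(sample_weight, dtype=np.float32)
    return float(np.sum(per_example * sample_weight) / np.sum(sample_weight))


def forward_pass_fn(model_weights, batch_input, training=True):
    """Runs prediction and loss on one (x, y) batch."""
    x, y = batch_input
    predictions = predict_on_batch_fn(model_weights, x, training)
    # Assumed result container: a plain dict stands in for the real return type.
    return {
        "predictions": predictions,
        "loss": loss_fn(predictions, y),
        "num_examples": x.shape[0],
    }


# One batch matching a hypothetical input_spec of x: [None, 3], y: [None, 1].
x = np.ones((4, 3), dtype=np.float32)
y = np.full((4, 1), 2.0, dtype=np.float32)
result = forward_pass_fn(initial_weights, (x, y), training=False)
```

Because the weights are passed in as an argument on every call, the same callables can be evaluated under any weights of matching structure, which is the sense in which the model is parameterized by its weights.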

initial_weights

input_spec

Methods

finalize_metrics

forward_pass

Runs the forward pass and returns results.

initialize_metrics_state

loss

Returns the loss value based on the model output and the label.

predict_on_batch

Returns tensor(s) interpretable by the loss function.

update_metrics_state

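The three metrics methods (initialize_metrics_state, update_metrics_state, finalize_metrics) mirror the metrics_fns 3-tuple passed to the constructor. The sketch below shows that initialize/update/finalize lifecycle with hand-written NumPy callables. The argument order of update_metrics_state, the state layout, and the mean-absolute-error metric are assumptions for illustration, not signatures confirmed by this page.

```python
import numpy as np


def initialize_metrics_state():
    """Creates an empty accumulator for the metric."""
    return {"abs_error_sum": 0.0, "num_examples": 0}


def update_metrics_state(state, labels, predictions, sample_weight=None):
    """Returns a new state that includes one more batch (assumed signature)."""
    del sample_weight  # unweighted metric in this sketch
    labels = np.asarray(labels, dtype=np.float64)
    predictions = np.asarray(predictions, dtype=np.float64)
    # Build a fresh dict instead of mutating `state`, keeping the function pure.
    return {
        "abs_error_sum": state["abs_error_sum"]
        + float(np.sum(np.abs(predictions - labels))),
        "num_examples": state["num_examples"] + int(labels.shape[0]),
    }


def finalize_metrics(state):
    """Turns the accumulated state into reportable metric values."""
    count = max(state["num_examples"], 1)  # avoid dividing by zero
    return {"mean_absolute_error": state["abs_error_sum"] / count}


# The 3-tuple in the order described for metrics_fns.
metrics_fns = (initialize_metrics_state, update_metrics_state, finalize_metrics)

# Two hypothetical batches of (labels, predictions).
batches = [([1.0, 2.0], [1.5, 2.0]), ([3.0], [2.0])]

state = initialize_metrics_state()
for labels, predictions in batches:
    state = update_metrics_state(state, labels, predictions)
final_metrics = finalize_metrics(state)
```

Keeping the state as an explicit value that is passed in and returned, rather than stored on the model, matches the functional style of the rest of the class.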