A model whose forward pass is parameterized by model weights.
```python
tff.learning.models.FunctionalModel(
    *,
    initial_weights: ModelWeights,
    forward_pass_fn: Callable[[ModelWeights, Any, bool], variable.BatchOutput],
    predict_on_batch_fn: Callable[[ModelWeights, Any, bool], Any],
    loss_fn: Callable[[Any, Any, Any], Any],
    metrics_fns: tuple[
        InitializeMetricsStateFn, UpdateMetricsStateFn, FinalizeMetricsFn
    ] = (empty_metrics_state, noop_update_metrics, noop_finalize_metrics),
    input_spec: Any
)
```
| Args | |
|---|---|
| `initial_weights` | A 2-tuple `(trainable, non_trainable)` whose two elements are sequences of weights. Weights must be values convertible to `tf.Tensor` (e.g. `numpy.ndarray`, Python sequences, etc.), but not `tf.Tensor` values themselves. |
| `forward_pass_fn` | A `tf.function`-decorated callable that takes three arguments: `model_weights`, with the same structure as `initial_weights`; `batch_input`, a nested structure of tensors matching `input_spec`; and `training`, a boolean indicating whether the call is part of a training pass (for example, for Dropout or BatchNormalization). |
| `predict_on_batch_fn` | A `tf.function`-decorated callable that takes three arguments: `model_weights`, with the same structure as `initial_weights`; `x`, the first element of `batch_input` (or of `input_spec`); and `training`, a boolean indicating whether the call is part of a training pass (for example, for Dropout or BatchNormalization). |
| `loss_fn` | A callable that takes three arguments: `output`, the tensor(s) returned by `predict_on_batch` in a form the loss function can interpret; `label`, the second element of `batch_input`; and an optional `sample_weight` that weights the output. |
| `metrics_fns` | A 3-tuple of callables that, respectively, initialize the metrics state, update the metrics state, and finalize the metric values. These can be the result of `tff.learning.metrics.create_functional_metric_fns` or custom user-written callables. |
| `input_spec` | A 2-tuple `(x, y)` where each element is a nested structure of `tf.TensorSpec` that defines the shapes and dtypes of the `batch_input` passed to `forward_pass_fn`. `x` corresponds to batched model inputs and `y` to the batched labels for those inputs. |
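The following sketch illustrates the argument contracts above for a one-feature linear regression. It is plain Python rather than TFF code: ordinary functions stand in for the `tf.function`-decorated callables, lists stand in for tensors, and `BatchOutputSketch` is a hypothetical stand-in for `tff.learning.BatchOutput` (its `loss`, `predictions`, and `num_examples` fields are an assumption of this sketch). Only the shape of the contract follows the table: weights as a `(trainable, non_trainable)` pair, `batch_input` as `(x, y)`, and an explicit `training` flag.

```python
from typing import NamedTuple


class BatchOutputSketch(NamedTuple):
    """Hypothetical stand-in for tff.learning.BatchOutput."""
    loss: float
    predictions: list
    num_examples: int


# initial_weights is a 2-tuple (trainable, non_trainable).
# trainable holds [slope, intercept]; this model has no non-trainable weights.
initial_weights = ([0.5, 1.0], [])


def predict_on_batch_fn(model_weights, x, training):
    """Returns one prediction per example: w * x_i + b."""
    del training  # A linear model behaves the same in training and inference.
    (w, b), _ = model_weights
    return [w * xi + b for xi in x]


def loss_fn(output, label, sample_weight=None):
    """Mean squared error, optionally weighted per example."""
    if sample_weight is None:
        sample_weight = [1.0] * len(label)
    total = sum(sw * (o - l) ** 2 for o, l, sw in zip(output, label, sample_weight))
    return total / sum(sample_weight)


def forward_pass_fn(model_weights, batch_input, training):
    """Predicts on x, scores the predictions against y, and packages the results."""
    x, y = batch_input
    predictions = predict_on_batch_fn(model_weights, x, training)
    return BatchOutputSketch(
        loss=loss_fn(predictions, y),
        predictions=predictions,
        num_examples=len(x),
    )


batch = ([0.0, 2.0], [1.0, 3.0])  # (x, y)
out = forward_pass_fn(initial_weights, batch, training=True)
print(out)
```

Here `forward_pass_fn` is built from `predict_on_batch_fn` and `loss_fn`, which is one simple way to keep the three callables consistent with each other; the table above does not require that composition.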
| Attributes | |
|---|---|
| `initial_weights` | |
| `input_spec` | |
Methods
finalize_metrics

```python
@tf.function
finalize_metrics(
    state: types.MetricsState
) -> collections.OrderedDict[str, Any]
```

Returns the finalized metric values computed from the metrics `state`.
forward_pass

```python
@tf.function
forward_pass(
    model_weights: ModelWeights,
    batch_input: Any,
    training: bool = True
) -> tff.learning.BatchOutput
```

Runs the forward pass on `batch_input` using `model_weights` and returns the results as a `tff.learning.BatchOutput`.
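Because the forward pass receives the weights as an argument instead of reading them from model state, a training loop can compute new weights and simply run the forward pass again. The pure-Python sketch below assumes an unweighted mean-squared-error loss on a linear model; `predict`, `forward_pass`, and `sgd_step` are hypothetical helpers, not TFF functions, and the gradient is written out analytically rather than computed with `tf.GradientTape`.

```python
def predict(model_weights, x, training=True):
    """Linear model: one prediction w * x_i + b per example."""
    del training  # This model behaves the same in training and inference.
    (w, b), _ = model_weights
    return [w * xi + b for xi in x]


def forward_pass(model_weights, batch_input, training=True):
    """Returns (loss, predictions, num_examples) using mean squared error."""
    x, y = batch_input
    predictions = predict(model_weights, x, training)
    loss = sum((p - t) ** 2 for p, t in zip(predictions, y)) / len(x)
    return loss, predictions, len(x)


def sgd_step(model_weights, batch_input, learning_rate):
    """Returns NEW weights after one gradient-descent step; inputs are not mutated."""
    (w, b), non_trainable = model_weights
    x, y = batch_input
    predictions = predict(model_weights, x)
    n = len(x)
    # Analytic gradients of mean((w * x + b - y) ** 2) with respect to w and b.
    grad_w = sum(2 * (p - t) * xi for p, t, xi in zip(predictions, y, x)) / n
    grad_b = sum(2 * (p - t) for p, t in zip(predictions, y)) / n
    return ([w - learning_rate * grad_w, b - learning_rate * grad_b], non_trainable)


weights = ([0.0, 0.0], [])        # (trainable=[w, b], non_trainable=[])
batch = ([1.0, 2.0], [2.0, 4.0])  # (x, y)

loss_before, _, _ = forward_pass(weights, batch)
new_weights = sgd_step(weights, batch, learning_rate=0.1)
loss_after, _, _ = forward_pass(new_weights, batch)
print(loss_before, loss_after)  # the loss drops after one step
```

`sgd_step` returns a fresh weights tuple and leaves the old one untouched, so the same forward pass can be evaluated against either set of weights.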
initialize_metrics_state

```python
@tf.function
initialize_metrics_state() -> types.MetricsState
```

Returns the initial metrics state.
loss

```python
loss(
    output: Any, label: Any, sample_weight: Optional[Any] = None
) -> float
```

Returns the loss value based on the model output and the label.
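`sample_weight` is optional. One common use, shown in this pure-Python sketch, is masking: giving padded examples a weight of 0 removes them from the average. The weighted mean-squared-error formula and the helper name `weighted_mse` are assumptions for illustration; the actual value is whatever the supplied `loss_fn` computes.

```python
def weighted_mse(output, label, sample_weight=None):
    """Weighted mean squared error; sample_weight=None weights every example equally."""
    if sample_weight is None:
        sample_weight = [1.0] * len(label)
    total_weight = sum(sample_weight)
    if total_weight == 0:
        return 0.0  # Every example is masked out, so nothing contributes.
    weighted_sum = sum(
        sw * (o - l) ** 2 for o, l, sw in zip(output, label, sample_weight)
    )
    return weighted_sum / total_weight


output = [1.0, 2.0, 9.0]
label = [1.0, 3.0, 0.0]  # the third example is padding
print(weighted_mse(output, label))                   # counts all three examples
print(weighted_mse(output, label, [1.0, 1.0, 0.0]))  # 0.5: padding masked out
```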
predict_on_batch

```python
@tf.function
predict_on_batch(
    model_weights: ModelWeights,
    x: Any,
    training: bool = True
)
```

Runs the model on `x` using `model_weights` and returns tensor(s) interpretable by the loss function.
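The `training` flag and the non-trainable weights often work together. This simplified sketch, loosely modeled on batch normalization, centers the inputs on the batch mean during training and on a stored moving mean (kept in the non-trainable weights) during inference. It is not TFF or Keras code: the function is hypothetical, and unlike a real BatchNormalization layer it neither rescales by the variance nor updates the moving mean.

```python
def predict_on_batch(model_weights, x, training=True):
    """Centers x, then applies w * x_centered + b.

    training=True centers on the batch mean; training=False centers on the
    stored moving mean, which lives in the non-trainable weights.
    """
    (w, b), (moving_mean,) = model_weights
    center = sum(x) / len(x) if training else moving_mean
    return [w * (xi - center) + b for xi in x]


weights = ([2.0, 1.0], [0.0])  # (trainable=[w, b], non_trainable=[moving_mean])
x = [1.0, 3.0]
print(predict_on_batch(weights, x, training=True))   # [-1.0, 3.0]: batch mean is 2.0
print(predict_on_batch(weights, x, training=False))  # [3.0, 7.0]: moving mean is 0.0
```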
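update_metrics_state

```python
@tf.function
update_metrics_state(
    state: GenericMetricsState,
    labels: Any,
    batch_output: tff.learning.BatchOutput,
    sample_weight: Optional[Any] = None
) -> GenericMetricsState
```

Updates the metrics `state` with `labels`, `batch_output`, and an optional `sample_weight`, and returns the new state.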
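The three metrics methods follow the `metrics_fns` triple: initialize a state, fold each batch into it, then finalize. In this pure-Python sketch the state is a dict of running sums and `batch_output` is a dict with a `predictions` key; both are hypothetical stand-ins for the TFF structures (`types.MetricsState`, `tff.learning.BatchOutput`), and the metric names are illustrative.

```python
import collections


def initialize_metrics_state():
    """Returns a fresh metrics state with all running sums at zero."""
    return {"squared_error_sum": 0.0, "weight_sum": 0.0, "num_examples": 0}


def update_metrics_state(state, labels, batch_output, sample_weight=None):
    """Returns a NEW state that adds this batch's weighted squared errors."""
    predictions = batch_output["predictions"]
    if sample_weight is None:
        sample_weight = [1.0] * len(labels)
    squared_errors = sum(
        sw * (p - l) ** 2 for p, l, sw in zip(predictions, labels, sample_weight)
    )
    return {
        "squared_error_sum": state["squared_error_sum"] + squared_errors,
        "weight_sum": state["weight_sum"] + sum(sample_weight),
        "num_examples": state["num_examples"] + len(labels),
    }


def finalize_metrics(state):
    """Converts the running sums into final metric values."""
    weight_sum = state["weight_sum"]
    mse = state["squared_error_sum"] / weight_sum if weight_sum else 0.0
    return collections.OrderedDict(
        [("mean_squared_error", mse), ("num_examples", state["num_examples"])]
    )


# Same order as `metrics_fns`: (initialize, update, finalize).
init_fn, update_fn, finalize_fn = (
    initialize_metrics_state, update_metrics_state, finalize_metrics)

state = init_fn()
state = update_fn(state, labels=[1.0, 3.0], batch_output={"predictions": [1.0, 2.0]})
state = update_fn(state, labels=[0.0], batch_output={"predictions": [2.0]})
print(finalize_fn(state))
```

Keeping running sums rather than running means in the state lets batches of different sizes combine correctly; the division happens once, in `finalize_metrics`.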