An instance of tff.Type that represents the type of a single
batch of data to use for training. This type should be constructed with
standard Python containers (such as collections.OrderedDict) of the sort
that are expected as parameters to loss_fn.
model_type
An instance of tff.Type that represents the type of the model.
As with batch_type, this type should be constructed with standard
Python containers (such as collections.OrderedDict) of the sort that are
expected as parameters to loss_fn.
loss_fn
A loss function for the model. Must be a Python function that takes
two parameters: the model and a single batch of data (with types
matching model_type and batch_type, respectively).
step_size
The step size to use during training (an np.float32).
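
As a rough illustration of how these arguments fit together, the sketch below builds a model and a batch as collections.OrderedDict containers and defines a loss function over them. It uses plain NumPy so that it runs on its own. The names make_model, make_batch, BATCH_SIZE and NUM_FEATURES, the linear model, and the (model, batch) parameter order are all assumptions made for this example and do not come from the library. The matching tff.Type values for batch_type and model_type would be built separately and are not shown here.

```python
import collections

import numpy as np

# Hypothetical sizes, chosen only for illustration.
BATCH_SIZE = 4
NUM_FEATURES = 3


def make_model():
    """Returns a linear model as an OrderedDict of parameters."""
    return collections.OrderedDict(
        weights=np.zeros((NUM_FEATURES,), dtype=np.float32),
        bias=np.float32(0.0),
    )


def make_batch():
    """Returns one batch of data as an OrderedDict of arrays."""
    return collections.OrderedDict(
        x=np.ones((BATCH_SIZE, NUM_FEATURES), dtype=np.float32),
        y=np.ones((BATCH_SIZE,), dtype=np.float32),
    )


def loss_fn(model, batch):
    """Mean squared error of the linear model on a single batch.

    The (model, batch) parameter order is an assumption of this sketch.
    """
    predictions = batch["x"] @ model["weights"] + model["bias"]
    return np.mean((predictions - batch["y"]) ** 2)


# The step size is passed as an np.float32, as the argument list requires.
step_size = np.float32(0.1)

loss = loss_fn(make_model(), make_batch())
```

The containers passed to loss_fn have the same nesting and field names that batch_type and model_type would describe, which is the correspondence the argument descriptions above ask for.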