Base model for TFRS models.

Many recommender models are relatively complex, and do not neatly fit into supervised or unsupervised paradigms. This base class makes it easy to define custom training and test losses for such complex models.

This is done by asking the user to implement the following methods:

  • __init__ to set up your model. Variable, task, loss, and metric initialization should go here.
  • compute_loss to define the training loss. The method takes as input the raw features passed into the model, and returns a loss tensor for training. As part of doing so, it should also update the model's metrics.
  • [Optional] call to define how the model computes its predictions. This is not always necessary: for example, two-tower retrieval models have two well-defined submodels whose call methods are normally used directly.
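The two required methods can be sketched for the two-tower case as follows. This is a minimal illustration under stated assumptions, not code taken from TFRS. The class TwoTowerModel, its num_users/num_items/dim parameters, and the user_id/item_id feature names are hypothetical. To keep the sketch dependent only on TensorFlow, it subclasses tf.keras.Model and writes an in-batch softmax loss by hand. With TFRS you would subclass tfrs.models.Model instead, and would typically delegate the loss to a task such as tfrs.tasks.Retrieval.

```python
import tensorflow as tf


class TwoTowerModel(tf.keras.Model):
    """Hypothetical two-tower retrieval model: __init__ and compute_loss only.

    With TFRS installed you would subclass tfrs.models.Model instead; plain
    tf.keras.Model is used here so the sketch needs only TensorFlow.
    """

    def __init__(self, num_users=100, num_items=50, dim=8):
        super().__init__()
        # Variables (the two towers), the loss, and the metric are created here.
        self.query_tower = tf.keras.layers.Embedding(num_users, dim)
        self.candidate_tower = tf.keras.layers.Embedding(num_items, dim)
        self.xent = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        self.top1 = tf.keras.metrics.CategoricalAccuracy(name="in_batch_top1")

    def compute_loss(self, inputs, training=False):
        # The towers are called directly, so the model defines no call().
        queries = self.query_tower(inputs["user_id"])         # (batch, dim)
        candidates = self.candidate_tower(inputs["item_id"])  # (batch, dim)
        # Hand-written in-batch softmax (an assumption standing in for
        # tfrs.tasks.Retrieval): row i's positive is item i, and the other
        # items in the batch serve as negatives.
        logits = tf.matmul(queries, candidates, transpose_b=True)  # (batch, batch)
        labels = tf.eye(tf.shape(logits)[0])
        # Update the model's metrics as part of computing the loss.
        self.top1.update_state(labels, logits)
        return self.xent(labels, logits)


model = TwoTowerModel()
batch = {
    "user_id": tf.constant([1, 2, 3, 4]),
    "item_id": tf.constant([10, 20, 30, 40]),
}
loss = model.compute_loss(batch, training=True)
print(float(loss), float(model.top1.result()))
```

Each user's positive item sits on the diagonal of the logits matrix. Freshly initialized embeddings therefore produce near-zero logits, and the loss starts close to ln(batch size), about 1.39 for this batch of 4.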

Note that this base class is a thin convenience wrapper around tf.keras.Model. Equivalent functionality can easily be achieved by overriding the train_step and test_step methods of a plain Keras model. Doing so also makes it easy to build even more complex training mechanisms, such as using different optimizers for different variables or manipulating gradients.

Keras has an excellent tutorial on how to do this: its guide to customizing what happens in fit().
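A rough plain-Keras equivalent that overrides train_step and test_step directly might look like the sketch below. It is an illustration under stated assumptions rather than the TFRS implementation. PlainKerasRanker, the x/rating feature names, and the synthetic data are hypothetical, and the loss is an ordinary mean squared error.

```python
import numpy as np
import tensorflow as tf


class PlainKerasRanker(tf.keras.Model):
    """Hypothetical rating model that overrides train_step/test_step itself."""

    def __init__(self):
        super().__init__()
        self.scorer = tf.keras.Sequential([
            tf.keras.layers.Dense(16, activation="relu"),
            tf.keras.layers.Dense(1),
        ])
        self.loss_fn = tf.keras.losses.MeanSquaredError()
        self.rmse = tf.keras.metrics.RootMeanSquaredError(name="rmse")

    @property
    def metrics(self):
        # Listing metrics here lets Keras reset them between epochs.
        return [self.rmse]

    def call(self, features, training=False):
        return self.scorer(features["x"], training=training)

    def compute_loss(self, features, training=False):
        # Same contract as compute_loss above: raw features in, loss tensor
        # out, updating the metrics along the way.
        predictions = self(features, training=training)
        labels = features["rating"]
        self.rmse.update_state(labels, predictions)
        return self.loss_fn(labels, predictions)

    def train_step(self, features):
        with tf.GradientTape() as tape:
            loss = self.compute_loss(features, training=True)
        grads = tape.gradient(loss, self.trainable_variables)
        # Custom gradient handling (clipping, per-variable optimizers) would go here.
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss, "rmse": self.rmse.result()}

    def test_step(self, features):
        loss = self.compute_loss(features, training=False)
        return {"loss": loss, "rmse": self.rmse.result()}


# Synthetic data (hypothetical): ratings are a linear function of 4 features.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4)).astype("float32")
rating = (x @ np.array([1.0, -2.0, 0.5, 3.0], dtype="float32"))[:, None]
ds = tf.data.Dataset.from_tensor_slices({"x": x, "rating": rating}).batch(32)

model = PlainKerasRanker()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))
history = model.fit(ds, epochs=10, verbose=0)
results = model.evaluate(ds, return_dict=True, verbose=0)
```

Because train_step is written out explicitly, it is also the natural place for the more complex mechanisms mentioned above, such as per-variable optimizers or gradient manipulation.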

call

Calls the model on new inputs.

In this case, call simply reapplies all ops in the graph to the new inputs (i.e., it builds a new computational graph from the provided inputs).

Args:
  inputs: A tensor or list of tensors.
  training: Boolean or boolean scalar tensor, indicating whether to run the network in training mode or inference mode.
  mask: A mask or list of masks. A mask can be either a tensor or None (no mask).

Returns:
  A tensor if there is a single output, or a list of tensors if there is more than one output.


compute_loss

Defines the loss function.

Args:
  inputs: A data structure of tensors holding the raw inputs to the model. These will usually contain labels and weights as well as features.
  training: Whether the model is in training mode.

Returns:
  Loss tensor.