Base model for TFRS models.
tfrs.models.Model( *args, **kwargs )
Many recommender models are relatively complex, and do not neatly fit into supervised or unsupervised paradigms. This base class makes it easy to define custom training and test losses for such complex models.
This is done by asking the user to implement the following methods:
__init__ to set up your model. Variable, task, loss, and metric initialization should go here.
compute_loss to define the training loss. The method takes as input the raw features passed into the model and returns a loss tensor for training. As part of doing so, it should also update the model's metrics.
call to define how the model computes its predictions. This is not always necessary: for example, two-tower retrieval models have two well-defined submodels whose call methods are normally used directly.
Note that this base class is a thin convenience wrapper for tf.keras.Model, and
equivalent functionality can easily be achieved by overriding the
train_step and test_step methods of a plain Keras model. Doing so also makes it easy
to build even more complex training mechanisms, such as using
different optimizers for different variables or manipulating gradients.
The Keras guide "Customizing what happens in fit()" explains how to do this.
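A minimal sketch of the plain-Keras route, assuming TensorFlow 2.x and no `tfrs` dependency. The class `PlainRatingModel`, its feature keys, and the toy data are hypothetical. `train_step` runs the forward pass under a `tf.GradientTape` and applies one optimizer update, which is also the natural place to manipulate gradients; `test_step` computes the same loss without updating weights.

```python
import numpy as np
import tensorflow as tf


class PlainRatingModel(tf.keras.Model):
    """Hypothetical plain-Keras model with custom train_step/test_step."""

    def __init__(self, num_users: int, num_items: int, dim: int = 8):
        super().__init__()
        self.user_embedding = tf.keras.layers.Embedding(num_users, dim)
        self.item_embedding = tf.keras.layers.Embedding(num_items, dim)
        self.head = tf.keras.layers.Dense(1)
        self.mse = tf.keras.losses.MeanSquaredError()

    def call(self, inputs, training=False):
        user = self.user_embedding(inputs["user_id"])
        item = self.item_embedding(inputs["item_id"])
        return self.head(tf.concat([user, item], axis=1))

    def train_step(self, data):
        # Forward pass and loss are recorded on the tape.
        with tf.GradientTape() as tape:
            predictions = self(data, training=True)
            loss = self.mse(data["rating"], predictions)
        grads = tape.gradient(loss, self.trainable_variables)
        # Gradients could be clipped or otherwise modified here.
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}

    def test_step(self, data):
        # Same loss, no weight update.
        predictions = self(data, training=False)
        return {"loss": self.mse(data["rating"], predictions)}


# Hypothetical toy data: 64 random (user, item, rating) triples.
rng = np.random.default_rng(0)
features = {
    "user_id": rng.integers(0, 10, size=64),
    "item_id": rng.integers(0, 20, size=64),
    "rating": rng.uniform(1.0, 5.0, size=(64, 1)).astype("float32"),
}
dataset = tf.data.Dataset.from_tensor_slices(features).batch(16)

model = PlainRatingModel(num_users=10, num_items=20)
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.1))
history = model.fit(dataset, epochs=2, verbose=0)
results = model.evaluate(dataset, verbose=0, return_dict=True)
```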
call( inputs, training=None, mask=None )
Calls the model on new inputs.
In this case, call just reapplies all ops in the graph to the new inputs (i.e. it builds a new computational graph from the provided inputs).
Args:
|inputs|Input tensor, or dict/list/tuple of input tensors.|
|training|Boolean or boolean scalar tensor, indicating whether to run the model in training mode or inference mode.|
|mask|A mask or list of masks. A mask can be either a tensor or None (no mask).|
Returns:
|A tensor if there is a single output, or a list of tensors if there is more than one output.|
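A small sketch of the return-value rule, assuming TensorFlow 2.x. The two functional models below are hypothetical and share the same layers: calling the single-output model yields one tensor, while calling the two-output model yields a list of two tensors. Passing `training=False` selects inference mode.

```python
import numpy as np
import tensorflow as tf

# Hypothetical graph: one shared hidden layer feeding two heads.
inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
score = tf.keras.layers.Dense(1, name="score")(hidden)
logits = tf.keras.layers.Dense(3, name="logits")(hidden)

single = tf.keras.Model(inputs, score)            # one output
multi = tf.keras.Model(inputs, [score, logits])   # two outputs

x = np.random.rand(5, 4).astype("float32")
y_single = single(x, training=False)  # a single tensor of shape (5, 1)
y_multi = multi(x, training=False)    # a list: shapes (5, 1) and (5, 3)
```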
compute_loss( inputs, training: bool = False ) -> tf.Tensor
Defines the loss function.
Args:
|inputs|A data structure of tensors: raw inputs to the model. These will usually contain labels and weights as well as features.|
|training|Whether the model is in training mode.|