Base model for TFRS models.
tfrs.models.Model(
*args, **kwargs
)
Many recommender models are relatively complex, and do not neatly fit into supervised or unsupervised paradigms. This base class makes it easy to define custom training and test losses for such complex models.
This is done by asking the user to implement the following methods:
- `__init__` to set up your model. Variable, task, loss, and metric initialization should go here.
- `compute_loss` to define the training loss. The method takes as input the raw features passed into the model, and returns a loss tensor for training. As part of doing so, it should also update the model's metrics.
- [Optional] `call` to define how the model computes its predictions. This is not always necessary: for example, two-tower retrieval models have two well-defined submodels whose `call` methods are normally used directly.
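To make this division of labor concrete, here is a framework-free sketch. `SketchModel` and `LinearRegressor` are hypothetical names, not part of TFRS. The base class owns `train_step`, and the subclass supplies only `__init__`, `call`, and `compute_loss`. Gradients are approximated with central finite differences purely to keep the example dependency-free (the real class differentiates with `tf.GradientTape`). Metric updates are omitted because the finite-difference loop calls `compute_loss` several times per step.

```python
from typing import Dict, List


class SketchModel:
    """Hypothetical stand-in for tfrs.models.Model (not the real class).

    The base class owns the training step; subclasses only define the loss.
    """

    def __init__(self, learning_rate: float = 0.05) -> None:
        self.params: Dict[str, float] = {}
        self.learning_rate = learning_rate

    def compute_loss(self, inputs: Dict[str, List[float]], training: bool = False) -> float:
        raise NotImplementedError("Subclasses must implement compute_loss.")

    def train_step(self, inputs: Dict[str, List[float]], eps: float = 1e-6) -> Dict[str, float]:
        loss = self.compute_loss(inputs, training=True)
        # Central finite differences stand in for tf.GradientTape here.
        grads: Dict[str, float] = {}
        for name in list(self.params):
            original = self.params[name]
            self.params[name] = original + eps
            loss_plus = self.compute_loss(inputs, training=True)
            self.params[name] = original - eps
            loss_minus = self.compute_loss(inputs, training=True)
            self.params[name] = original  # restore before the next parameter
            grads[name] = (loss_plus - loss_minus) / (2 * eps)
        # Apply a plain SGD update to every parameter.
        for name, grad in grads.items():
            self.params[name] -= self.learning_rate * grad
        return {"loss": loss}


class LinearRegressor(SketchModel):
    def __init__(self) -> None:
        # __init__: set up variables (two scalar weights here).
        super().__init__(learning_rate=0.05)
        self.params = {"w": 0.0, "b": 0.0}

    def call(self, features: List[float]) -> List[float]:
        # call: how the model turns features into predictions.
        return [self.params["w"] * x + self.params["b"] for x in features]

    def compute_loss(self, inputs: Dict[str, List[float]], training: bool = False) -> float:
        # compute_loss: the raw inputs carry labels as well as features.
        predictions = self.call(inputs["features"])
        labels = inputs["labels"]
        return sum((p - y) ** 2 for p, y in zip(predictions, labels)) / len(labels)


model = LinearRegressor()
batch = {"features": [0.0, 1.0, 2.0, 3.0], "labels": [1.0, 3.0, 5.0, 7.0]}
history = [model.train_step(batch)["loss"] for _ in range(500)]
print(history[0], history[-1], model.params)
```

The user-facing subclass never touches gradients or updates; that is the point of routing everything through `compute_loss`.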
Note that this base class is a thin convenience wrapper for `tf.keras.Model`, and
equivalent functionality can easily be achieved by overriding the `train_step`
and `test_step` methods of a plain Keras model. Doing so also makes it easy
to build even more complex training mechanisms, such as the use of
different optimizers for different variables, or manipulating gradients.
Keras has an excellent tutorial on how to do this: the "Customizing what happens in `fit()`" guide.
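The same idea without any base class means writing the training step by hand. The sketch below is framework-free Python, and `custom_train_step`, `clip_by_global_norm`, and the per-variable `learning_rates` dictionary are hypothetical names chosen for illustration. On a toy linear model it shows the two customizations named above: gradients are clipped by their global norm before being applied, and each variable gets its own learning rate as a simplified stand-in for separate optimizers. In Keras, equivalent logic would live inside an overridden `train_step`.

```python
import math
from typing import Dict, List, Tuple


def clip_by_global_norm(grads: Dict[str, float], clip_norm: float) -> Dict[str, float]:
    """Rescale all gradients jointly so their combined L2 norm is at most clip_norm."""
    global_norm = math.sqrt(sum(g * g for g in grads.values()))
    if global_norm <= clip_norm:
        return dict(grads)
    scale = clip_norm / global_norm
    return {name: g * scale for name, g in grads.items()}


def custom_train_step(
    params: Dict[str, float],
    batch: Dict[str, List[float]],
    learning_rates: Dict[str, float],
    clip_norm: float = 1.0,
) -> Tuple[Dict[str, float], float]:
    """One hand-written training step for a linear model y = w * x + b."""
    xs, ys = batch["features"], batch["labels"]
    n = len(xs)
    # Forward pass and mean-squared-error loss.
    errors = [params["w"] * x + params["b"] - y for x, y in zip(xs, ys)]
    loss = sum(e * e for e in errors) / n
    # Analytic gradients of the MSE loss with respect to w and b.
    grads = {
        "w": 2.0 / n * sum(e * x for e, x in zip(errors, xs)),
        "b": 2.0 / n * sum(errors),
    }
    # Manipulate the gradients before applying them.
    grads = clip_by_global_norm(grads, clip_norm)
    # A separate learning rate per variable (simplified stand-in for separate optimizers).
    new_params = {name: params[name] - learning_rates[name] * grads[name] for name in params}
    return new_params, loss


params = {"w": 0.0, "b": 0.0}
batch = {"features": [0.0, 1.0, 2.0, 3.0], "labels": [1.0, 3.0, 5.0, 7.0]}
for _ in range(2000):
    params, loss = custom_train_step(params, batch, learning_rates={"w": 0.05, "b": 0.1})
print(params, loss)
```

Because the whole step is explicit, any further change (gradient noise, freezing a variable, logging per-variable norms) is a local edit to this one function.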
Methods
call
call(
inputs, training=None, mask=None
)
Calls the model on new inputs.
In this case `call` just reapplies all ops in the graph to the new inputs
(i.e., it builds a new computational graph from the provided inputs).
Arguments | |
---|---|
`inputs` | A tensor or list of tensors. |
`training` | Boolean or boolean scalar tensor, indicating whether to run the model in training mode or inference mode. |
`mask` | A mask or list of masks. A mask can be either a tensor or `None` (no mask). |
Returns | |
---|---|
A tensor if there is a single output, or a list of tensors if there is more than one output. | |
compute_loss
compute_loss(
inputs,
training: bool = False
) -> tf.Tensor
Defines the loss function.
Args | |
---|---|
`inputs` | A data structure of tensors: raw inputs to the model. These will usually contain labels and weights as well as features. |
`training` | Whether the model is in training mode. |
Returns | |
---|---|
Loss tensor. | |
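A minimal sketch of a `compute_loss` whose raw inputs carry labels and per-example weights alongside features. It is plain Python rather than TensorFlow, and the input keys (`features`, `label`, `weight`), the `params` argument, and the linear model are assumptions for illustration; a real implementation would operate on tensors and would typically also update the model's metrics. Here the weighted squared errors are averaged over the batch size, so a zero weight simply drops an example's contribution.

```python
from typing import Dict, List


def compute_loss(
    inputs: Dict[str, List[float]],
    params: Dict[str, float],
    training: bool = False,
) -> float:
    """Weighted mean squared error for a linear model y = w * x + b.

    `inputs` bundles features, labels, and per-example weights, mirroring
    the kind of raw input that compute_loss receives.
    """
    # `training` is accepted for signature parity; a model with dropout or
    # similar layers would branch on it. It is unused here.
    features, labels, weights = inputs["features"], inputs["label"], inputs["weight"]
    squared_errors = [
        (params["w"] * x + params["b"] - y) ** 2 for x, y in zip(features, labels)
    ]
    # Weight each example, then average over the batch size.
    return sum(wt * se for wt, se in zip(weights, squared_errors)) / len(labels)


batch = {"features": [1.0, 2.0], "label": [3.0, 5.0], "weight": [1.0, 0.0]}
print(compute_loss(batch, {"w": 1.0, "b": 0.0}))  # 2.0
```

With weights `[1.0, 0.0]`, only the first example's squared error (4.0) counts, and dividing by the batch size of 2 gives 2.0.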