tfr.extension.premade.TFRBertTask

Task object for tf-ranking BERT.

Inherits From: RankingTask

Args
params the task configuration instance, which can be any of dataclass, ConfigDict, namedtuple, etc.
logging_dir a string pointing to where the model, summaries, etc. will be saved. You can also write additional files to this directory.
name the task name.

Attributes

logging_dir

name Returns the name of this module as passed or determined in the constructor.

name_scope Returns a tf.name_scope instance for this class.
non_trainable_variables Sequence of non-trainable variables owned by this module and its submodules.
submodules Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

>>> import tensorflow as tf
>>> a = tf.Module()
>>> b = tf.Module()
>>> c = tf.Module()
>>> a.b = b
>>> b.c = c
>>> list(a.submodules) == [b, c]
True
>>> list(b.submodules) == [c]
True
>>> list(c.submodules) == []
True

task_config

trainable_variables Sequence of trainable variables owned by this module and its submodules.

variables Sequence of variables owned by this module and its submodules.

Methods

aggregate_logs

Aggregates over logs. This runs on CPU in eager mode.
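Because aggregation runs eagerly on the host, the accumulated state can be ordinary Python objects. The accumulate-per-step pattern can be sketched in plain Python as follows. This is a hypothetical illustration, not the library's implementation. The function signature, the dict-of-lists state, and the "scores"/"labels" keys are all assumptions.

```python
from typing import Dict, List, Optional


def aggregate_logs(state: Optional[Dict[str, List[float]]],
                   step_outputs: Dict[str, List[float]]) -> Dict[str, List[float]]:
  """Hypothetical sketch: append one step's per-example outputs to the state."""
  if state is None:
    # First call: start an empty list for every key in the step outputs.
    state = {key: [] for key in step_outputs}
  for key, values in step_outputs.items():
    state[key].extend(values)
  return state


# Two evaluation steps, each producing scores and labels for two examples.
state = None
for step_outputs in [
    {"scores": [0.9, 0.1], "labels": [1.0, 0.0]},
    {"scores": [0.4, 0.7], "labels": [0.0, 1.0]},
]:
  state = aggregate_logs(state, step_outputs)

print(state["scores"])  # [0.9, 0.1, 0.4, 0.7]
```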

build_inputs

Returns a tf.data.Dataset for the tf-ranking BERT task.

build_losses

Standard interface to compute losses.

Args
labels optional label tensors.
model_outputs a nested structure of output tensors.
aux_losses auxiliary loss tensors, i.e. the losses property of a keras.Model.

Returns
The total loss tensor.
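The total loss combines a ranking loss over the model outputs with any auxiliary losses. A minimal plain-Python sketch of that combination follows. It assumes a listwise softmax cross-entropy as the ranking loss (one common choice; the loss this task actually uses is set by its configuration), a mean over lists, and a simple sum of auxiliary terms. The function names are hypothetical.

```python
import math
from typing import List, Sequence


def softmax_cross_entropy(labels: Sequence[float], scores: Sequence[float]) -> float:
  """Listwise softmax cross-entropy for a single list of documents."""
  label_sum = sum(labels)
  if label_sum == 0:
    return 0.0  # No relevant documents: this list contributes no loss.
  max_score = max(scores)  # Subtract the max for numerical stability.
  log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
  # Cross-entropy between normalized labels and the softmax over scores.
  return -sum((y / label_sum) * (s - log_norm) for y, s in zip(labels, scores))


def build_losses(labels: List[List[float]], model_outputs: List[List[float]],
                 aux_losses: Sequence[float] = ()) -> float:
  """Hypothetical sketch: mean ranking loss over lists plus auxiliary losses."""
  per_list = [softmax_cross_entropy(y, s) for y, s in zip(labels, model_outputs)]
  total_loss = sum(per_list) / len(per_list)
  # Auxiliary losses (e.g. regularization terms collected by the model) are added.
  return total_loss + sum(aux_losses)


labels = [[1.0, 0.0], [0.0, 1.0]]
scores = [[0.0, 0.0], [0.0, 0.0]]
print(round(build_losses(labels, scores, aux_losses=[0.5]), 4))  # 1.1931
```

With uniform scores over two documents, each list costs log 2, so the result is log 2 + 0.5.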

build_metrics

Gets streaming metrics for training/validation.
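A streaming metric keeps running totals across batches and reports the aggregate on demand. The class below is a hypothetical, framework-free NDCG@k with the update_state / result / reset_state shape of a Keras metric. It is not the metric object that build_metrics returns. The 2^y - 1 gain and log2 discount follow the common NDCG convention and are assumed here.

```python
import math
from typing import Optional, Sequence


class StreamingNDCG:
  """Hypothetical streaming NDCG@k: averages per-list NDCG across batches."""

  def __init__(self, topn: Optional[int] = None):
    self.topn = topn  # None means use the whole list.
    self.reset_state()

  def reset_state(self) -> None:
    self._total = 0.0
    self._count = 0

  def _dcg(self, gains: Sequence[float]) -> float:
    # Gain 2^y - 1 discounted by log2(rank + 1), with ranks starting at 1.
    return sum((2.0 ** g - 1.0) / math.log2(rank + 2)
               for rank, g in enumerate(gains[:self.topn]))

  def update_state(self, labels: Sequence[Sequence[float]],
                   scores: Sequence[Sequence[float]]) -> None:
    for y, s in zip(labels, scores):
      # Labels reordered by descending score, versus the ideal ordering.
      ranked = [label for _, label in sorted(zip(s, y), key=lambda p: -p[0])]
      ideal_dcg = self._dcg(sorted(y, reverse=True))
      self._total += self._dcg(ranked) / ideal_dcg if ideal_dcg > 0 else 0.0
      self._count += 1

  def result(self) -> float:
    return self._total / self._count if self._count else 0.0


metric = StreamingNDCG(topn=2)
metric.update_state([[1.0, 0.0]], [[0.9, 0.1]])  # Perfect ranking: NDCG = 1.
metric.update_state([[1.0, 0.0]], [[0.1, 0.9]])  # Relevant document ranked second.
print(round(metric.result(), 4))  # 0.8155
```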

build_model

[Optional] Creates the model architecture.

Returns
A model instance.

create_optimizer

Creates a TF optimizer from the configurations.

Args
optimizer_config the parameters of the optimization settings.
runtime_config the parameters of the runtime.
dp_config the parameters of differential privacy.

Returns
A tf.optimizers.Optimizer object.

inference_step

Performs the forward step.

With distribution strategies, this method runs on devices.

Args
inputs a dictionary of input tensors.
model the keras.Model.

Returns
Model outputs.

initialize

Loads a pretrained checkpoint (if one exists) and then trains from iteration 0.
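The point of this method is warm-starting rather than resuming: weights come from the checkpoint, but the step counter starts over at 0. A hypothetical plain-Python sketch of that behavior follows. A JSON file of weights stands in for a real TF checkpoint, and the assumption that only weights found in the checkpoint are overwritten (so a newly added ranking head keeps its fresh initialization) is illustrative, not taken from the library.

```python
import json
import os
import tempfile
from typing import Dict, List, Optional


def initialize(model_weights: Dict[str, List[float]],
               init_checkpoint: Optional[str]) -> int:
  """Hypothetical sketch: warm-start from a checkpoint, then start at step 0."""
  if init_checkpoint and os.path.exists(init_checkpoint):
    with open(init_checkpoint) as f:
      # Only weights present in the checkpoint are overwritten; others keep
      # their fresh initialization (e.g. a newly added ranking head).
      model_weights.update(json.load(f))
  # Unlike resuming a run, the step counter is not restored.
  return 0


weights = {"encoder": [0.0, 0.0], "head": [0.0]}
with tempfile.TemporaryDirectory() as tmp:
  ckpt = os.path.join(tmp, "pretrained.json")
  with open(ckpt, "w") as f:
    json.dump({"encoder": [0.3, -0.1]}, f)
  start_step = initialize(weights, ckpt)

print(weights, start_step)  # {'encoder': [0.3, -0.1], 'head': [0.0]} 0
```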

process_compiled_metrics

Processes and updates compiled_metrics.

Called when using the compile/fit API.

Args
compiled_metrics the compiled metrics (model.compiled_metrics).
labels a tensor or a nested structure of tensors.
model_outputs a tensor or a nested structure of tensors. For example, the output of the Keras model built by self.build_model.

process_metrics

Processes and updates metrics.

Called when using a custom training loop.

Args
metrics a nested structure of metrics objects, i.e. the return value of self.build_metrics.
labels a tensor or a nested structure of tensors.
model_outputs a tensor or a nested structure of tensors. For example, the output of the Keras model built by self.build_model.
**kwargs other args.
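Since metrics may be a nested structure, updating them amounts to walking the structure and calling update_state on each leaf. The following is a hypothetical plain-Python sketch of that walk. The MeanSquaredError class is a stand-in metric invented for illustration, only dicts, lists and tuples are treated as containers, and extra keyword arguments are accepted but ignored.

```python
from typing import Any, Sequence


class MeanSquaredError:
  """Hypothetical minimal streaming metric with a Keras-like interface."""

  def __init__(self):
    self.total, self.count = 0.0, 0

  def update_state(self, labels: Sequence[float], outputs: Sequence[float]) -> None:
    for y, p in zip(labels, outputs):
      self.total += (y - p) ** 2
      self.count += 1

  def result(self) -> float:
    return self.total / self.count if self.count else 0.0


def process_metrics(metrics: Any, labels, model_outputs, **kwargs) -> None:
  """Hypothetical sketch: update every metric in a nested dict/list structure."""
  if isinstance(metrics, dict):
    for metric in metrics.values():
      process_metrics(metric, labels, model_outputs, **kwargs)
  elif isinstance(metrics, (list, tuple)):
    for metric in metrics:
      process_metrics(metric, labels, model_outputs, **kwargs)
  else:
    # Leaf: a single metric object.
    metrics.update_state(labels, model_outputs)


metrics = {"train": [MeanSquaredError()], "extra": {"mse": MeanSquaredError()}}
process_metrics(metrics, labels=[1.0, 0.0], model_outputs=[0.5, 0.5])
print(metrics["train"][0].result())  # 0.25
```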

reduce_aggregated_logs

Calculates aggregated metrics and writes predictions to CSV.
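Once all evaluation steps have been aggregated, the reduce step turns the collected outputs into final metrics and a predictions file. The sketch below is hypothetical. The state keys ("query_ids", "scores", "labels"), the use of mean reciprocal rank (MRR) as the aggregate metric, and the CSV column names are all assumptions chosen for illustration.

```python
import csv
import os
import tempfile
from collections import defaultdict
from typing import Dict, List


def reduce_aggregated_logs(state: Dict[str, List], output_path: str) -> Dict[str, float]:
  """Hypothetical sketch: compute MRR over queries and write predictions to CSV."""
  per_query = defaultdict(list)
  for qid, score, label in zip(state["query_ids"], state["scores"], state["labels"]):
    per_query[qid].append((score, label))
  reciprocal_ranks = []
  for docs in per_query.values():
    ranked = sorted(docs, key=lambda d: -d[0])  # Highest score first.
    # Reciprocal of the 1-based rank of the first relevant document (0 if none).
    rr = next((1.0 / (i + 1) for i, (_, y) in enumerate(ranked) if y > 0), 0.0)
    reciprocal_ranks.append(rr)
  with open(output_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["query_id", "score", "label"])
    writer.writerows(zip(state["query_ids"], state["scores"], state["labels"]))
  return {"mrr": sum(reciprocal_ranks) / len(reciprocal_ranks)}


state = {"query_ids": ["q1", "q1", "q2", "q2"],
         "scores": [0.9, 0.2, 0.3, 0.8],
         "labels": [1, 0, 1, 0]}
with tempfile.TemporaryDirectory() as tmp:
  path = os.path.join(tmp, "predictions.csv")
  logs = reduce_aggregated_logs(state, path)
  with open(path, newline="") as f:
    rows = list(csv.reader(f))

print(logs)     # {'mrr': 0.75}
print(rows[1])  # ['q1', '0.9', '1']
```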

train_step

Performs the forward and backward passes.

With distribution strategies, this method runs on devices.

Args
inputs a dictionary of input tensors.
model the model, forward pass definition.
optimizer the optimizer for this training step.
metrics a nested structure of metrics objects.

Returns
A dictionary of logs.
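The forward/backward/update cycle behind a training step can be sketched without a framework. In this hypothetical example, a linear scorer with squared error stands in for the BERT ranking model, hand-derived gradients stand in for automatic differentiation, and plain SGD with a learning_rate argument stands in for the optimizer. The loss is reported under the key 'loss', the value of the class variable loss listed at the end of this page, on the assumption that the logs are keyed that way.

```python
from typing import Dict, List, Sequence


def train_step(inputs: Dict[str, Sequence], weights: List[float],
               learning_rate: float = 0.1) -> Dict[str, float]:
  """Hypothetical sketch: one forward/backward pass for a linear scorer."""
  features, labels = inputs["features"], inputs["labels"]
  n = len(labels)
  # Forward pass: score each example and compute the mean squared error.
  preds = [sum(w * x for w, x in zip(weights, row)) for row in features]
  loss = sum((p - y) ** 2 for p, y in zip(preds, labels)) / n
  # Backward pass: d(loss)/d(w_j) = 2/n * sum_i (p_i - y_i) * x_ij.
  grads = [
      2.0 / n * sum((p - y) * row[j]
                    for p, y, row in zip(preds, labels, features))
      for j in range(len(weights))
  ]
  # Optimizer update (plain SGD), applied to the weights in place.
  for j, g in enumerate(grads):
    weights[j] -= learning_rate * g
  return {"loss": loss}


weights = [0.0]
inputs = {"features": [[1.0], [2.0]], "labels": [1.0, 2.0]}
first = train_step(inputs, weights)
second = train_step(inputs, weights)
print(first, second)  # {'loss': 2.5} {'loss': 0.625}
```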

validation_step

Validation step.

With distribution strategies, this method runs on devices.

Args
inputs a dictionary of input tensors.
model the keras.Model.
metrics a nested structure of metrics objects.

Returns
A dictionary of logs.

with_name_scope

Decorator to automatically enter the module name scope.

>>> import tensorflow as tf
>>> class MyModule(tf.Module):
...   @tf.Module.with_name_scope
...   def __call__(self, x):
...     if not hasattr(self, 'w'):
...       self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
...     return tf.matmul(x, self.w)

Using the above module produces tf.Variables and tf.Tensors whose names include the module name:

>>> mod = MyModule()
>>> mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
>>> mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>

Args
method The method to wrap.

Returns
The original method wrapped such that it enters the module's name scope.
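The decorator's effect can be illustrated with a plain-Python analogue: a wrapper that pushes the module's name onto a scope stack for the duration of the call, so names created inside pick up the prefix. This is a hypothetical sketch of the idea, not TensorFlow's implementation. The global stack and the name_scope and scoped_name helpers are invented for illustration.

```python
import contextlib
import functools
from typing import Callable, List

_SCOPE_STACK: List[str] = []  # Hypothetical global stack of active scope names.


@contextlib.contextmanager
def name_scope(name: str):
  """Push a scope name for the duration of the block."""
  _SCOPE_STACK.append(name)
  try:
    yield
  finally:
    _SCOPE_STACK.pop()


def scoped_name(base: str) -> str:
  """Prefix a name with every active scope, e.g. 'my_module/Variable'."""
  return "/".join(_SCOPE_STACK + [base])


def with_name_scope(method: Callable) -> Callable:
  """Hypothetical decorator: run `method` inside its module's name scope."""
  @functools.wraps(method)
  def wrapper(self, *args, **kwargs):
    with name_scope(self.name):
      return method(self, *args, **kwargs)
  return wrapper


class MyModule:
  def __init__(self, name: str = "my_module"):
    self.name = name

  @with_name_scope
  def create_variable(self) -> str:
    return scoped_name("Variable")


print(MyModule().create_variable())  # my_module/Variable
print(scoped_name("Variable"))       # Variable (the scope was exited)
```

Using functools.wraps keeps the wrapped method's name and docstring intact.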

Class Variables

loss 'loss'