tfm.nlp.tasks.TaggingTask

Task object for token tagging, e.g., named-entity recognition (NER) or part-of-speech (POS) tagging.

Inherits From: Task

Args
params: the task configuration instance, which can be any of dataclass, ConfigDict, namedtuple, etc.
logging_dir: a string pointing to where the model, summaries, etc. will be saved. You can also write additional files to this directory.
name: the task name.

Attributes
logging_dir
task_config

Methods

aggregate_logs

Aggregates the logs returned from each validation step.
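As a rough illustration of what aggregation across validation steps involves, the pure-Python sketch below accumulates each step's label and prediction ids into per-sequence lists and drops padded positions. The dictionary keys (`label_ids`, `predict_ids`) and the convention that padded positions carry a negative label id are assumptions for illustration, not the library's documented log format.

```python
from typing import Dict, List, Optional

# Hypothetical log keys; the real task may use different names.
LABEL_KEY = "label_ids"
PREDICT_KEY = "predict_ids"


def aggregate_logs(state: Optional[Dict[str, List[List[int]]]],
                   step_outputs: Dict[str, List[List[int]]]
                   ) -> Dict[str, List[List[int]]]:
    """Appends one validation step's outputs to the running state.

    Hypothetical sketch: positions whose label id is negative are treated as
    padding and dropped, so each stored sequence holds only real tokens.
    """
    if state is None:
        state = {LABEL_KEY: [], PREDICT_KEY: []}
    for labels, preds in zip(step_outputs[LABEL_KEY], step_outputs[PREDICT_KEY]):
        kept = [(l, p) for l, p in zip(labels, preds) if l >= 0]
        state[LABEL_KEY].append([l for l, _ in kept])
        state[PREDICT_KEY].append([p for _, p in kept])
    return state


# Two validation steps, each with a batch of one padded sequence.
state = None
state = aggregate_logs(state, {LABEL_KEY: [[1, 2, -1]], PREDICT_KEY: [[1, 0, 0]]})
state = aggregate_logs(state, {LABEL_KEY: [[0, -1, -1]], PREDICT_KEY: [[0, 2, 2]]})
print(state)  # → {'label_ids': [[1, 2], [0]], 'predict_ids': [[1, 0], [0]]}
```

Keeping the state as plain per-sequence lists means the reduction over all steps can run once, after validation finishes.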

build_inputs

Returns a tf.data.Dataset for the tagging task.
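Tagging inputs pair each token with a label id, and sequences are padded to a fixed length before batching. The sketch below shows that padding step in plain Python. The function name `pad_example` and the `pad_label=-1` convention (which lets padded positions be masked out of the loss) are assumptions for illustration, not the library's data loader.

```python
from typing import List, Tuple


def pad_example(token_ids: List[int], label_ids: List[int], seq_length: int,
                pad_token: int = 0, pad_label: int = -1
                ) -> Tuple[List[int], List[int], List[int]]:
    """Truncates or pads one example to `seq_length`.

    Hypothetical helper. Returns (input_ids, input_mask, label_ids); padded
    tokens get mask 0 and label `pad_label` so later stages can ignore them.
    """
    if len(token_ids) != len(label_ids):
        raise ValueError("each token needs exactly one label")
    token_ids = token_ids[:seq_length]
    label_ids = label_ids[:seq_length]
    n_pad = seq_length - len(token_ids)
    input_ids = token_ids + [pad_token] * n_pad
    input_mask = [1] * len(token_ids) + [0] * n_pad
    labels = label_ids + [pad_label] * n_pad
    return input_ids, input_mask, labels


print(pad_example([5, 6, 7], [2, 0, 1], seq_length=5))
# → ([5, 6, 7, 0, 0], [1, 1, 1, 0, 0], [2, 0, 1, -1, -1])
```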

build_losses

Standard interface to compute losses.

Args
labels: optional label tensors.
model_outputs: a nested structure of output tensors.
aux_losses: auxiliary loss tensors, i.e., losses in keras.Model.

Returns
The total loss tensor.
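Because the tagging loss is computed per token, padded positions should not contribute to it. A minimal pure-Python sketch of a masked, token-averaged softmax cross-entropy follows. It assumes padded positions carry a negative label id and that the model produces per-token logits; `aux_losses` is not modeled here, and this illustrates the idea rather than the library's implementation.

```python
import math
from typing import List


def masked_token_loss(labels: List[List[int]],
                      logits: List[List[List[float]]]) -> float:
    """Mean softmax cross-entropy over tokens whose label id is >= 0.

    Hypothetical sketch. labels: [batch][seq] class ids, negative = padding.
    logits: [batch][seq][num_classes] unnormalized scores.
    """
    total, count = 0.0, 0
    for seq_labels, seq_logits in zip(labels, logits):
        for label, scores in zip(seq_labels, seq_logits):
            if label < 0:  # padding: excluded from numerator and denominator
                continue
            # Numerically stable log-sum-exp: subtract the max score first.
            m = max(scores)
            log_z = m + math.log(sum(math.exp(s - m) for s in scores))
            total += log_z - scores[label]
            count += 1
    # An all-padding batch yields 0.0 instead of dividing by zero.
    return total / count if count else 0.0


loss = masked_token_loss(labels=[[0, 1, -1]],
                         logits=[[[0.0, 0.0], [0.0, 0.0], [5.0, -5.0]]])
print(loss)  # → 0.6931471805599453 (log 2); the padded third token is ignored
```

Dividing by the number of real tokens, rather than by the padded sequence length, keeps the loss scale independent of how much padding a batch contains.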

build_metrics

Gets streaming metrics for training/validation.

build_model

[Optional] Creates model architecture.

Returns
A model instance.
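A tagging model typically places a per-token classification layer on top of an encoder. The sketch below shows only that head: the same dense projection maps each token's hidden vector to one score per class. The function name, weights, and shapes are hypothetical, and the encoder is omitted.

```python
from typing import List

Vector = List[float]


def token_classification_head(hidden: List[List[Vector]],
                              weights: List[Vector],
                              bias: Vector) -> List[List[Vector]]:
    """Applies one shared dense layer to every token (hypothetical sketch).

    hidden:  [batch][seq][hidden_size] encoder outputs.
    weights: [hidden_size][num_classes] projection matrix.
    bias:    [num_classes].
    Returns logits with shape [batch][seq][num_classes].
    """
    num_classes = len(bias)
    logits = []
    for sequence in hidden:
        seq_logits = []
        for h in sequence:
            scores = [bias[c] + sum(h[i] * weights[i][c] for i in range(len(h)))
                      for c in range(num_classes)]
            seq_logits.append(scores)
        logits.append(seq_logits)
    return logits


out = token_classification_head(hidden=[[[1.0, 0.0], [0.0, 2.0]]],
                                weights=[[1.0, -1.0], [0.5, 0.5]],
                                bias=[0.0, 1.0])
print(out)  # → [[[1.0, 0.0], [1.0, 2.0]]]
```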

create_optimizer

Creates a TF optimizer from configurations.

Args
optimizer_config: the parameters of the optimization settings.
runtime_config: the parameters of the runtime.
dp_config: the parameters of differential privacy.

Returns
A tf.optimizers.Optimizer object.
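Optimizer configurations commonly pair an optimizer with a learning-rate schedule. As an illustration only, the sketch below computes a linear warmup followed by linear (power-1 polynomial) decay. Whether a given TaggingTask configuration uses this schedule, and the function and parameter names here, are assumptions.

```python
def warmup_linear_decay(step: int, peak_lr: float, warmup_steps: int,
                        total_steps: int, end_lr: float = 0.0) -> float:
    """Hypothetical schedule: linear warmup to `peak_lr`, then linear decay.

    The rate rises from 0 to `peak_lr` over `warmup_steps`, then falls
    linearly to `end_lr` at `total_steps` and stays there.
    """
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * step / warmup_steps
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return peak_lr + (end_lr - peak_lr) * progress


for s in (0, 50, 100, 550, 1000):
    print(s, warmup_linear_decay(s, peak_lr=1e-4, warmup_steps=100, total_steps=1000))
```

Warmup keeps early updates small while optimizer statistics are still noisy; the decay phase then anneals the rate toward `end_lr`.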

inference_step

Performs the forward step.
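For tagging, the forward step yields per-token logits, and the predicted tag at each position is usually the highest-scoring class. The sketch below shows that argmax step with a hypothetical id-to-tag list. It does not call the library.

```python
from typing import List, Sequence


def predict_tags(logits: List[List[Sequence[float]]],
                 id_to_tag: Sequence[str]) -> List[List[str]]:
    """Maps [batch][seq][num_classes] logits to tag strings via argmax.

    Hypothetical sketch; ties resolve to the lowest class id.
    """
    predictions = []
    for seq_logits in logits:
        ids = [max(range(len(scores)), key=scores.__getitem__) for scores in seq_logits]
        predictions.append([id_to_tag[i] for i in ids])
    return predictions


TAGS = ["O", "B-PER", "I-PER"]  # hypothetical label list
print(predict_tags([[[0.1, 2.0, 0.3], [0.0, 0.5, 1.5], [3.0, 0.0, 0.0]]], TAGS))
# → [['B-PER', 'I-PER', 'O']]
```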

initialize

[Optional] A callback function used as CheckpointManager's init_fn.

This function will be called when no checkpoint is found for the model. If there is a checkpoint, the checkpoint will be loaded and this function will not be called. You can use this callback function to load a pretrained checkpoint, saved under a directory other than the model_dir.

Args
model: the keras.Model built or used by this task.
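The contract described above is that `initialize` runs only when no checkpoint exists, typically to warm-start from pretrained weights. The plain-Python sketch below mimics that contract with dictionaries standing in for model variables. The names `restore_or_initialize`, `load_pretrained`, and the weight keys are hypothetical; a real implementation would use TensorFlow checkpoint utilities instead.

```python
from typing import Callable, Dict, List, Optional

Weights = Dict[str, List[float]]


def load_pretrained(model: Weights, pretrained: Weights) -> List[str]:
    """Copies pretrained values into `model` for names present in both.

    Hypothetical helper. Variables missing from the pretrained set (e.g. a
    freshly added tagging head) keep their initial values.
    """
    restored = []
    for name, value in pretrained.items():
        if name in model and len(model[name]) == len(value):
            model[name] = list(value)
            restored.append(name)
    return restored


def restore_or_initialize(model: Weights, checkpoint: Optional[Weights],
                          init_fn: Callable[[Weights], object]) -> str:
    """Restores from `checkpoint` if one exists; otherwise calls `init_fn`."""
    if checkpoint is not None:
        model.update(checkpoint)
        return "restored"
    init_fn(model)
    return "initialized"


model = {"encoder/kernel": [0.0, 0.0], "head/kernel": [0.0]}
pretrained = {"encoder/kernel": [0.3, -0.2], "mlm/kernel": [1.0]}
status = restore_or_initialize(model, None, lambda m: load_pretrained(m, pretrained))
print(status, model)
# → initialized {'encoder/kernel': [0.3, -0.2], 'head/kernel': [0.0]}
```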

process_compiled_metrics

Processes and updates compiled_metrics.

Called when using the compile/fit API.

Args
compiled_metrics: the compiled metrics (model.compiled_metrics).
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model.

process_metrics

Processes and updates metrics.

Called when using the custom training loop API.

Args
metrics: a nested structure of metrics objects, i.e., the return value of self.build_metrics.
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model.
**kwargs: other arguments.
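In a custom training loop, metrics are stateful objects that are updated once per step and read at the end. The sketch below imitates that pattern with a hypothetical masked token-accuracy class whose `update_state`/`result`/`reset_state` method names follow the Keras metric convention. It is pure Python and is not the object that `build_metrics` actually returns; the negative-label-means-padding convention is also an assumption.

```python
from typing import List


class MaskedTokenAccuracy:
    """Hypothetical streaming accuracy over tokens whose label id is >= 0."""

    def __init__(self, name: str = "accuracy") -> None:
        self.name = name
        self.reset_state()

    def reset_state(self) -> None:
        self.correct = 0
        self.total = 0

    def update_state(self, labels: List[List[int]], predictions: List[List[int]]) -> None:
        for seq_labels, seq_preds in zip(labels, predictions):
            for label, pred in zip(seq_labels, seq_preds):
                if label < 0:  # skip padding
                    continue
                self.total += 1
                self.correct += int(label == pred)

    def result(self) -> float:
        return self.correct / self.total if self.total else 0.0


def process_metrics(metrics: List[MaskedTokenAccuracy],
                    labels: List[List[int]],
                    predictions: List[List[int]]) -> None:
    """Updates every metric in place with one step's labels and predictions."""
    for metric in metrics:
        metric.update_state(labels, predictions)


metrics = [MaskedTokenAccuracy()]
process_metrics(metrics, [[1, 2, -1]], [[1, 0, 2]])    # 1 of 2 real tokens correct
process_metrics(metrics, [[0, -1, -1]], [[0, 0, 0]])   # running total: 2 of 3
print({m.name: m.result() for m in metrics})
```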

reduce_aggregated_logs

Reduces aggregated logs over validation steps.
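After all validation steps, the aggregated state is reduced to scalar metrics. The sketch below reduces a hypothetical aggregated state of per-sequence label and prediction ids to token-level accuracy. The keys are assumptions, and the library's reduction may report different metrics (for example, entity-level precision, recall, and F1).

```python
from typing import Dict, List


def reduce_aggregated_logs(aggregated: Dict[str, List[List[int]]]) -> Dict[str, float]:
    """Hypothetical reduction: token accuracy over all validation sequences."""
    correct = total = 0
    for labels, preds in zip(aggregated["label_ids"], aggregated["predict_ids"]):
        for label, pred in zip(labels, preds):
            total += 1
            correct += int(label == pred)
    return {"token_accuracy": correct / total if total else 0.0}


print(reduce_aggregated_logs({"label_ids": [[1, 2], [0]], "predict_ids": [[1, 0], [0]]}))
# → {'token_accuracy': 0.6666666666666666}
```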

train_step

Runs the forward and backward passes.

With distribution strategies, this method runs on devices.

Args
inputs: a dictionary of input tensors.
model: the model, i.e., the forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.

Returns
A dictionary of logs.
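The control flow of a train step (forward pass, loss, gradients, parameter update, returned logs) can be shown without TensorFlow. The sketch below runs one step of plain gradient descent on a hypothetical one-parameter linear model with squared-error loss; it illustrates the control flow only. In TensorFlow the gradient would typically come from `tf.GradientTape` and the update from `optimizer.apply_gradients`.

```python
from typing import Dict, List


def train_step(inputs: Dict[str, List[float]], params: Dict[str, float],
               learning_rate: float) -> Dict[str, float]:
    """Hypothetical step for y_hat = w * x with loss = mean((y_hat - y)^2)."""
    xs, ys = inputs["x"], inputs["y"]
    n = len(xs)
    # Forward pass.
    preds = [params["w"] * x for x in xs]
    loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
    # Backward pass: d(loss)/dw = mean(2 * (y_hat - y) * x).
    grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
    # Optimizer update (plain SGD), applied in place.
    params["w"] -= learning_rate * grad_w
    return {"loss": loss}


params = {"w": 0.0}
logs = train_step({"x": [1.0, 2.0], "y": [2.0, 4.0]}, params, learning_rate=0.1)
print(logs, params)  # → {'loss': 10.0} {'w': 1.0}
```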

validation_step

Validation step.

Args
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.

Returns
A dictionary of logs.
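A validation step runs the forward pass and computes the loss but leaves parameters unchanged. The sketch below uses a hypothetical one-parameter linear model and returns the loss under the `'loss'` key together with predictions that an aggregation step could collect. The `predictions` key is an assumption for illustration.

```python
from typing import Dict, List


def validation_step(inputs: Dict[str, List[float]],
                    params: Dict[str, float]) -> Dict[str, object]:
    """Hypothetical forward pass and loss for y_hat = w * x; no updates."""
    preds = [params["w"] * x for x in inputs["x"]]
    errors = [(p - y) ** 2 for p, y in zip(preds, inputs["y"])]
    return {"loss": sum(errors) / len(errors), "predictions": preds}


params = {"w": 1.5}
logs = validation_step({"x": [1.0, 2.0], "y": [2.0, 4.0]}, params)
print(logs)  # → {'loss': 0.625, 'predictions': [1.5, 3.0]}
```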

Class Variables
loss 'loss'