tfr.keras.pipeline.SimplePipeline

Pipeline for single-task training.

Inherits From: ModelFitPipeline, AbstractPipeline

It handles a single loss and works with SimpleDatasetBuilder. It can also be used with MultiLabelDatasetBuilder, in which case the same loss and all metrics are applied uniformly to every label and its predictions.
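To make the multi-label behavior concrete, here is a minimal pure-Python sketch (not TF-Ranking code) of what applying one loss uniformly means. It assumes a softmax-style listwise loss (labels normalized into a target distribution, cross-entropy against the softmax of the scores) and mirrors the Keras convention of summing the per-output losses when a single loss is applied to several outputs. The label names "click" and "purchase" are hypothetical.

```python
import math
from typing import Dict, List


def softmax_loss(labels: List[float], scores: List[float]) -> float:
    """Softmax-style listwise cross-entropy for one list of items (a sketch).

    Labels are normalized to sum to one; scores go through a numerically
    stable log-softmax.
    """
    total = sum(labels)
    if total <= 0:
        return 0.0  # A list with no relevant items contributes no loss.
    max_score = max(scores)
    log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    return -sum((y / total) * (s - log_norm) for y, s in zip(labels, scores))


def uniform_multi_label_loss(
    labels: Dict[str, List[float]], predictions: Dict[str, List[float]]
) -> float:
    """Applies the same loss to every label/prediction pair and sums them."""
    return sum(softmax_loss(labels[name], predictions[name]) for name in labels)


# Hypothetical labels and per-label scores for a single list of three items.
labels = {"click": [1.0, 0.0, 0.0], "purchase": [0.0, 0.0, 1.0]}
predictions = {"click": [2.0, 1.0, 0.0], "purchase": [0.0, 1.0, 2.0]}
print(round(uniform_multi_label_loss(labels, predictions), 4))  # 0.8152
```

Both labels get the identical loss function; only the label and prediction tensors differ per key.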

To customize the loss and metrics, subclass SimplePipeline and override build_loss, build_metrics, or build_weighted_metrics.

Example usage:

import tensorflow as tf
import tensorflow_ranking as tfr

context_feature_spec = {}
example_feature_spec = {
    "example_feature_1": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
# SimpleDatasetBuilder takes a single (label_name, feature_spec) pair.
label_spec = (
    "utility",
    tf.io.FixedLenFeature(shape=(1,), dtype=tf.float32, default_value=0.0))
dataset_hparams = tfr.keras.pipeline.DatasetHparams(
    train_input_pattern="train.dat",
    valid_input_pattern="valid.dat",
    train_batch_size=128,
    valid_batch_size=128)
pipeline_hparams = tfr.keras.pipeline.PipelineHparams(
    model_dir="model/",
    num_epochs=2,
    steps_per_epoch=5,
    validation_steps=2,
    learning_rate=0.01,
    loss="softmax_loss")
# SimpleModelBuilder stands in for any ModelBuilder that builds the ranking
# model from these feature specs.
model_builder = SimpleModelBuilder(
    context_feature_spec, example_feature_spec, mask_feature_name)
dataset_builder = tfr.keras.pipeline.SimpleDatasetBuilder(
    context_feature_spec,
    example_feature_spec,
    mask_feature_name,
    label_spec,
    dataset_hparams)
# Use a variable name that does not shadow the pipeline module.
ranking_pipeline = tfr.keras.pipeline.SimplePipeline(
    model_builder, dataset_builder, pipeline_hparams)
ranking_pipeline.train_and_validate(verbose=1)
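As a rough sanity check on these hyperparameters, assuming Keras semantics (each epoch runs steps_per_epoch training batches, validation runs validation_steps batches once per epoch, and each batch holds train_batch_size or valid_batch_size lists rather than individual items), the run above processes this many lists:

```python
# Hyperparameter values copied from the example above.
num_epochs = 2
steps_per_epoch = 5
validation_steps = 2
train_batch_size = 128
valid_batch_size = 128

train_steps = num_epochs * steps_per_epoch                   # optimizer updates
train_lists = train_steps * train_batch_size                 # lists seen in training
valid_lists_per_epoch = validation_steps * valid_batch_size  # lists per validation pass

print(train_steps, train_lists, valid_lists_per_epoch)  # 10 1280 256
```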

Args
model_builder A ModelBuilder instance used to build the model for model fit.
dataset_builder An AbstractDatasetBuilder instance that loads the training and validation datasets and creates signatures for the SavedModel.
hparams The pipeline hyperparameters, for example a PipelineHparams instance as in the example above.

Methods

build_callbacks

Sets up callbacks.

Example usage:

model_builder = ModelBuilder(...)
dataset_builder = DatasetBuilder(...)
hparams = PipelineHparams(...)
pipeline = SimplePipeline(model_builder, dataset_builder, hparams)
callbacks = pipeline.build_callbacks()

Returns
A list of tf.keras.callbacks.Callback or a tf.keras.callbacks.CallbackList for TensorBoard logging and checkpointing.
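For orientation only, here is a hand-built list of the two kinds of callbacks named above, using standard Keras classes. This is an assumption about the shape of the result, not the pipeline's implementation: the exact callbacks, their arguments, and file paths (here derived from a hypothetical model_dir of "model/") may differ, and the method may return a CallbackList instead.

```python
import os

import tensorflow as tf

model_dir = "model/"  # Hypothetical; mirrors PipelineHparams.model_dir.

callbacks = [
    # Writes training and validation summaries for TensorBoard.
    tf.keras.callbacks.TensorBoard(log_dir=os.path.join(model_dir, "logs")),
    # Saves the model weights during training.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(model_dir, "ckpt.weights.h5"),
        save_weights_only=True),
]
```

A list like this is what Keras Model.fit accepts through its callbacks argument.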

build_loss

See AbstractPipeline.

build_metrics

See AbstractPipeline.

build_weighted_metrics

See AbstractPipeline.

export_saved_model

Exports the trained model with signatures.

Example usage:

model_builder = ModelBuilder(...)
dataset_builder = DatasetBuilder(...)
hparams = PipelineHparams(...)
pipeline = SimplePipeline(model_builder, dataset_builder, hparams)
pipeline.export_saved_model(model_builder.build(), 'saved_model/')

Args
model Model to be saved.
export_to The directory to export the model to.
checkpoint If given, export the model with weights from this checkpoint.

train_and_validate

Main function to train and validate the model, using the distribution strategy configured in the hyperparameters (for example, a TPU strategy).

Example usage:

import tensorflow as tf
import tensorflow_ranking as tfr

context_feature_spec = {}
example_feature_spec = {
    "example_feature_1": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
# SimpleDatasetBuilder takes a single (label_name, feature_spec) pair.
label_spec = (
    "utility",
    tf.io.FixedLenFeature(shape=(1,), dtype=tf.float32, default_value=0.0))
dataset_hparams = tfr.keras.pipeline.DatasetHparams(
    train_input_pattern="train.dat",
    valid_input_pattern="valid.dat",
    train_batch_size=128,
    valid_batch_size=128)
pipeline_hparams = tfr.keras.pipeline.PipelineHparams(
    model_dir="model/",
    num_epochs=2,
    steps_per_epoch=5,
    validation_steps=2,
    learning_rate=0.01,
    loss="softmax_loss")
# SimpleModelBuilder stands in for any ModelBuilder that builds the ranking
# model from these feature specs.
model_builder = SimpleModelBuilder(
    context_feature_spec, example_feature_spec, mask_feature_name)
dataset_builder = tfr.keras.pipeline.SimpleDatasetBuilder(
    context_feature_spec,
    example_feature_spec,
    mask_feature_name,
    label_spec,
    dataset_hparams)
# Use a variable name that does not shadow the pipeline module.
ranking_pipeline = tfr.keras.pipeline.SimplePipeline(
    model_builder, dataset_builder, pipeline_hparams)
ranking_pipeline.train_and_validate(verbose=1)

Args
verbose An int for the verbosity level, passed to Keras Model.fit (0 = silent, 1 = progress bar, 2 = one line per epoch).