
tfr.keras.pipeline.ModelFitPipeline

Pipeline using model.fit to train a ranking tf.keras.Model.

Inherits From: AbstractPipeline

The ModelFitPipeline class is an abstract class that inherits from AbstractPipeline. It trains and validates a ranking model with model.fit under the distribution strategy specified in hparams.

To be implemented by subclasses: build_loss(), build_metrics(), and build_weighted_metrics().

Example subclass implementation:

import tensorflow_ranking as tfr

class BasicModelFitPipeline(tfr.keras.pipeline.ModelFitPipeline):

  def build_loss(self):
    return tfr.keras.losses.get('softmax_loss')

  def build_metrics(self):
    return [
        tfr.keras.metrics.get(
            'ndcg', topn=topn, name='ndcg_{}'.format(topn)
        ) for topn in [1, 5, 10]
    ]

  def build_weighted_metrics(self):
    return [
        tfr.keras.metrics.get(
            'ndcg', topn=topn, name='weighted_ndcg_{}'.format(topn)
        ) for topn in [1, 5, 10]
    ]

Args
model_builder A ModelBuilder instance for model fit.
dataset_builder An AbstractDatasetBuilder instance to load the training and validation datasets and create signatures for SavedModel.
hparams A dict containing model hyperparameters.
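The split between the base class and its subclasses follows the template-method pattern: the base class owns the training flow, and subclasses only supply the loss and metrics that are handed to model.compile. The pure-Python sketch below mimics that contract with a fake model so it runs without TensorFlow. `SketchPipeline`, `FakeModel`, `BasicSketchPipeline`, and `compile_model` are hypothetical names, not part of tfr.keras.pipeline, and the optimizer and other compile arguments the real pipeline may pass are omitted.

```python
from abc import ABC, abstractmethod


class FakeModel:
    """Hypothetical stand-in for tf.keras.Model that records compile() arguments."""

    def __init__(self):
        self.compile_kwargs = None

    def compile(self, **kwargs):
        self.compile_kwargs = kwargs


class SketchPipeline(ABC):
    """Hypothetical mirror of the ModelFitPipeline contract (not the library code)."""

    @abstractmethod
    def build_loss(self):
        """Return the loss passed to model.compile."""

    @abstractmethod
    def build_metrics(self):
        """Return the metrics passed to model.compile."""

    @abstractmethod
    def build_weighted_metrics(self):
        """Return the weighted metrics passed to model.compile."""

    def compile_model(self, model):
        # The base class decides when the hooks run; subclasses decide what they return.
        model.compile(
            loss=self.build_loss(),
            metrics=self.build_metrics(),
            weighted_metrics=self.build_weighted_metrics(),
        )
        return model


class BasicSketchPipeline(SketchPipeline):
    # Strings stand in for tfr.keras.losses / tfr.keras.metrics objects.
    def build_loss(self):
        return "softmax_loss"

    def build_metrics(self):
        return [f"ndcg_{topn}" for topn in (1, 5, 10)]

    def build_weighted_metrics(self):
        return [f"weighted_ndcg_{topn}" for topn in (1, 5, 10)]


model = BasicSketchPipeline().compile_model(FakeModel())
print(model.compile_kwargs["metrics"])  # ['ndcg_1', 'ndcg_5', 'ndcg_10']
```

Because the three hooks are abstract, a subclass that forgets one of them fails at construction time rather than partway through training.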

Methods

build_callbacks


Sets up Callbacks.

Example usage:

model_builder = ModelBuilder(...)
dataset_builder = DatasetBuilder(...)
hparams = PipelineHparams(...)
pipeline = BasicModelFitPipeline(model_builder, dataset_builder, hparams)
callbacks = pipeline.build_callbacks()

Returns
A list of tf.keras.callbacks.Callback or a tf.keras.callbacks.CallbackList for TensorBoard logging and checkpointing.

build_loss


Returns the loss for model.compile.

Example usage:

pipeline = BasicModelFitPipeline(model_builder, dataset_builder, hparams)
loss = pipeline.build_loss()

Returns
A tf.keras.losses.Loss or a dict or list of tf.keras.losses.Loss.
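The example subclass returns 'softmax_loss'. As a rough illustration of what a listwise softmax loss computes, the sketch below implements the standard formulation for a single list in plain Python: labels are normalized into a target distribution and compared, via cross-entropy, with the softmax of the scores. It is a simplified sketch, not the TF-Ranking implementation; masking, per-list weighting, and batching are omitted, and returning 0.0 for a list with no relevant items is an assumption.

```python
import math


def softmax_loss(labels, scores):
    """Listwise softmax cross-entropy for one list (simplified sketch)."""
    label_sum = sum(labels)
    if label_sum <= 0:
        return 0.0  # assumption: lists without relevant items contribute no loss
    # Numerically stable log-softmax: shift by the max score before exponentiating.
    max_score = max(scores)
    log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    # Cross-entropy between normalized labels and log-softmax of the scores.
    return -sum((y / label_sum) * (s - log_norm) for y, s in zip(labels, scores))


print(round(softmax_loss([1.0, 0.0, 0.0], [2.0, 1.0, 0.0]), 4))  # 0.4076
```

The loss shrinks as the relevant item's score grows relative to the others, which is what pushes relevant items toward the top of the ranking.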

build_metrics


Returns a list of ranking metrics for model.compile().

Example usage:

pipeline = BasicModelFitPipeline(model_builder, dataset_builder, hparams)
metrics = pipeline.build_metrics()

Returns
A list or a dict of tf.keras.metrics.Metric instances.
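The example subclass builds 'ndcg' metrics at topn 1, 5, and 10. The plain-Python sketch below shows what NDCG@k measures for a single list, assuming the common exponential gain (2^label - 1) and logarithmic rank discount (1 / log2(rank + 1)). It illustrates the metric, not the tfr.keras.metrics implementation; returning 0.0 when the ideal DCG is zero is an assumption.

```python
import math


def dcg_at_k(ranked_labels, k):
    # Gain 2^label - 1, discounted by 1 / log2(rank + 1) with 1-based ranks.
    return sum(
        (2.0 ** label - 1.0) / math.log2(rank + 1)
        for rank, label in enumerate(ranked_labels[:k], start=1)
    )


def ndcg_at_k(labels, scores, k):
    """NDCG@k for one list: DCG of the predicted order over the ideal DCG."""
    # Order labels by descending score to get the model's ranking.
    ranked = [label for _, label in sorted(zip(scores, labels), key=lambda p: -p[0])]
    ideal_dcg = dcg_at_k(sorted(labels, reverse=True), k)
    return dcg_at_k(ranked, k) / ideal_dcg if ideal_dcg > 0 else 0.0


labels, scores = [0.0, 1.0], [2.0, 1.0]
print([round(ndcg_at_k(labels, scores, k), 4) for k in (1, 2)])  # [0.0, 0.6309]
```

Here the model puts the irrelevant item first, so NDCG@1 is 0 and NDCG@2 only recovers the discounted gain of the relevant item at rank 2.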

build_weighted_metrics


Returns a list of weighted ranking metrics for model.compile.

Example usage:

pipeline = BasicModelFitPipeline(model_builder, dataset_builder, hparams)
weighted_metrics = pipeline.build_weighted_metrics()

Returns
A list or a dict of tf.keras.metrics.Metric instances.
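Weighted metrics go to the weighted_metrics argument of model.compile, which Keras evaluates with the sample weights supplied by the data. As a simplified view, the sketch below shows how per-list weights change the aggregate of per-list metric values. The metric values and weights are hypothetical, and how per-item weights are reduced to per-list weights varies by metric in TF-Ranking, so treat this only as an illustration of the averaging step.

```python
def weighted_mean(values, weights):
    """Weighted mean of per-list metric values (sketch of the averaging step)."""
    total_weight = sum(weights)
    if total_weight == 0:
        return 0.0  # assumption: no weight means no contribution
    return sum(v * w for v, w in zip(values, weights)) / total_weight


per_list_ndcg = [1.0, 0.5]   # hypothetical NDCG values for two lists
list_weights = [1.0, 3.0]    # hypothetical per-list weights
print(weighted_mean(per_list_ndcg, list_weights))  # 0.625
print(weighted_mean(per_list_ndcg, [1.0, 1.0]))    # 0.75 (equal weights)
```

With equal weights the result is the plain mean; upweighting the weaker list pulls the weighted metric down.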

export_saved_model


Exports the trained model with signatures.

Example usage:

model_builder = ModelBuilder(...)
dataset_builder = DatasetBuilder(...)
hparams = PipelineHparams(...)
pipeline = BasicModelFitPipeline(model_builder, dataset_builder, hparams)
pipeline.export_saved_model(model_builder.build(), 'saved_model/')

Args
model Model to be saved.
export_to The directory the model is exported to.
checkpoint If given, export the model with weights from this checkpoint.

train_and_validate


Main function to train and validate the model with model.fit under the distribution strategy specified in hparams (for example, a TPU strategy).

Example usage:

import tensorflow as tf
import tensorflow_ranking as tfr

context_feature_spec = {}
example_feature_spec = {
    "example_feature_1": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
label_spec = {
    "utility": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
dataset_hparams = tfr.keras.pipeline.DatasetHparams(
    train_input_pattern="train.dat",
    valid_input_pattern="valid.dat",
    train_batch_size=128,
    valid_batch_size=128)
pipeline_hparams = tfr.keras.pipeline.PipelineHparams(
    model_dir="model/",
    num_epochs=2,
    steps_per_epoch=5,
    validation_steps=2,
    learning_rate=0.01,
    loss="softmax_loss")
# SimpleModelBuilder is a placeholder not defined on this page; substitute your own ModelBuilder.
model_builder = SimpleModelBuilder(
    context_feature_spec, example_feature_spec, mask_feature_name)
dataset_builder = tfr.keras.pipeline.SimpleDatasetBuilder(
    context_feature_spec,
    example_feature_spec,
    mask_feature_name,
    label_spec,
    dataset_hparams)
# BasicModelFitPipeline is the example subclass defined at the top of this page.
pipeline = BasicModelFitPipeline(
    model_builder, dataset_builder, pipeline_hparams)
pipeline.train_and_validate(verbose=1)

Args
verbose An int for the verbosity level.
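To make the hyperparameters in the example concrete, the sketch below works out how much data one train_and_validate call touches. It assumes num_epochs, steps_per_epoch, and validation_steps map onto the epochs, steps_per_epoch, and validation_steps arguments of model.fit, that each batch element is one ranked list, and that the datasets supply full batches for every step. The helper fit_schedule is hypothetical.

```python
def fit_schedule(num_epochs, steps_per_epoch, validation_steps,
                 train_batch_size, valid_batch_size):
    """Steps and lists processed, assuming model.fit semantics and full batches."""
    return {
        "train_steps": num_epochs * steps_per_epoch,
        "train_lists": num_epochs * steps_per_epoch * train_batch_size,
        "valid_lists_per_epoch": validation_steps * valid_batch_size,
    }


# Values taken from the hparams in the example above.
print(fit_schedule(num_epochs=2, steps_per_epoch=5, validation_steps=2,
                   train_batch_size=128, valid_batch_size=128))
# {'train_steps': 10, 'train_lists': 1280, 'valid_lists_per_epoch': 256}
```

So the example is a smoke-test configuration: 10 optimizer steps over 1,280 lists, with 256 lists scored for validation at the end of each epoch.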