tfr.extension.pipeline.RankingPipeline

Class to set up the input, train and eval processes for a TF Ranking model.

An example use case is provided below:

import tensorflow as tf
import tensorflow_ranking as tfr

context_feature_columns = {
  "c1": tf.feature_column.numeric_column("c1", shape=(1,))
}
example_feature_columns = {
  "e1": tf.feature_column.numeric_column("e1", shape=(1,))
}

hparams = dict(
      train_input_pattern="/path/to/train/files",
      eval_input_pattern="/path/to/eval/files",
      train_batch_size=8,
      eval_batch_size=8,
      checkpoint_secs=120,
      num_checkpoints=1000,
      num_train_steps=10000,
      num_eval_steps=100,
      loss="softmax_loss",
      list_size=10,
      listwise_inference=False,
      convert_labels_to_binary=False,
      model_dir="/path/to/your/model/directory")

# See `tensorflow_ranking.estimator` for details about creating an estimator.
estimator = ...  # Replace the ellipsis with your own Estimator instance.

ranking_pipeline = tfr.ext.pipeline.RankingPipeline(
      context_feature_columns,
      example_feature_columns,
      hparams,
      estimator=estimator,
      label_feature_name="relevance",
      label_feature_type=tf.int64)
ranking_pipeline.train_and_eval()
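The hparams argument is a plain Python dict, so a misspelled or missing key is easy to overlook before launching a long training job. Below is a minimal pre-flight check, assuming the keys used in the example above are the ones your setup needs. missing_hparams and EXAMPLE_HPARAM_KEYS are hypothetical helpers, not part of TF-Ranking, and RankingPipeline may require a different set of keys.

```python
# Keys copied from the example hparams above. This tuple is an illustrative
# assumption, not the set of keys RankingPipeline itself validates.
EXAMPLE_HPARAM_KEYS = (
    "train_input_pattern", "eval_input_pattern",
    "train_batch_size", "eval_batch_size",
    "checkpoint_secs", "num_checkpoints",
    "num_train_steps", "num_eval_steps",
    "loss", "list_size", "listwise_inference",
    "convert_labels_to_binary", "model_dir",
)


def missing_hparams(hparams, required=EXAMPLE_HPARAM_KEYS):
    """Returns the required keys absent from `hparams`, in `required` order."""
    return [key for key in required if key not in hparams]


if __name__ == "__main__":
    # A deliberately incomplete dict: every other key is reported as missing.
    partial = {"train_input_pattern": "/path/to/train/files", "list_size": 10}
    print(missing_hparams(partial))
```

Raising an error when the returned list is non-empty, before constructing the pipeline, keeps configuration mistakes out of the training logs.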

Note that you can also:

  • pass best_exporter_metric and best_exporter_metric_higher_better to choose a different model export strategy.
  • pass dataset_reader to read input files with a different tf.data.Dataset class. We recommend using TFRecord files and storing your data in the tfr.data.ELWC format.
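The best exporter keeps the checkpoint whose evaluation results are best according to best_exporter_metric, or the one with the minimal loss when that argument is None. The selection rule can be sketched offline in plain Python. pick_best, the eval_results layout, and the metric key "metric/ndcg@10" are hypothetical illustrations; the real comparison happens inside the Estimator export machinery as each evaluation finishes, not over a precomputed dict.

```python
def pick_best(eval_results, metric_key=None, higher_is_better=True):
    """Returns the global step whose evaluation results are best.

    eval_results maps a step to a dict of metric values. When metric_key is
    None, the step with the smallest "loss" wins and higher_is_better is
    ignored, mirroring the documented default. Otherwise metric_key decides,
    and higher_is_better sets the direction of the comparison.
    """
    if not eval_results:
        raise ValueError("eval_results is empty")
    if metric_key is None:
        metric_name, prefer_higher = "loss", False
    else:
        metric_name, prefer_higher = metric_key, higher_is_better
    choose = max if prefer_higher else min
    return choose(eval_results, key=lambda step: eval_results[step][metric_name])


if __name__ == "__main__":
    results = {
        100: {"loss": 0.90, "metric/ndcg@10": 0.61},
        200: {"loss": 0.85, "metric/ndcg@10": 0.64},
        300: {"loss": 0.87, "metric/ndcg@10": 0.66},
    }
    print(pick_best(results))                    # smallest loss: 200
    print(pick_best(results, "metric/ndcg@10"))  # highest NDCG: 300
```

The example shows why the choice matters: the lowest-loss checkpoint and the best-NDCG checkpoint need not be the same.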

If you want to further customize certain RankingPipeline behaviors, create a subclass of RankingPipeline and override the relevant methods. We recommend overriding only the following methods:

  • _make_dataset, which builds the tf.data.Dataset for a TF-Ranking model.
  • _make_serving_input_fn, which defines the input function for serving.
  • _export_strategies, if you have more advanced model-export needs.

For example, to remove the best exporters and keep only the latest export, you can override _export_strategies:

class NoBestExporterRankingPipeline(tfr.ext.pipeline.RankingPipeline):
  def _export_strategies(self, event_file_pattern):
    del event_file_pattern
    latest_exporter = tf.estimator.LatestExporter(
        "latest_model",
        serving_input_receiver_fn=self._make_serving_input_fn())
    return [latest_exporter]

ranking_pipeline = NoBestExporterRankingPipeline(
      context_feature_columns,
      example_feature_columns,
      hparams,
      estimator=estimator)
ranking_pipeline.train_and_eval()

If you want to customize how datasets are read, you can override _make_dataset:

class CustomizedDatasetRankingPipeline(tfr.ext.pipeline.RankingPipeline):
  def _make_dataset(self,
                    batch_size,
                    list_size,
                    input_pattern,
                    randomize_input=True,
                    num_epochs=None):
    # Create your own dataset here; see `tfr.data.build_ranking_dataset` for reference.
    # `build_my_own_ranking_dataset` is a placeholder for your own reader.
    dataset = build_my_own_ranking_dataset(...)
    ...
    return dataset.map(self._features_and_labels)

ranking_pipeline = CustomizedDatasetRankingPipeline(
      context_feature_columns,
      example_feature_columns,
      hparams,
      estimator=estimator)
ranking_pipeline.train_and_eval()

Args

context_feature_columns (dict) Context (a.k.a. query) feature columns.
example_feature_columns (dict) Example (a.k.a. document) feature columns.
hparams (dict) A dict containing model hyperparameters.
estimator (Estimator) An Estimator instance used for model training and evaluation.
label_feature_name (str) The name of the label feature.
label_feature_type (tf.DType) The value type of the label feature.
dataset_reader (tf.data.Dataset class) The dataset class used to read the input files, such as tf.data.TFRecordDataset.
best_exporter_metric (str) Metric key for exporting the best model. If None, the model with the minimal loss value is exported.
best_exporter_metric_higher_better (bool) Whether a higher metric value is better. Only used if best_exporter_metric is not None.
size_feature_name (str) If set, the feature dictionary is populated with a feature of this name whose value is a tf.int32 Tensor of shape [batch_size] giving the actual size of each example list before padding and truncation. If None (the default), this feature is not generated.
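To make size_feature_name concrete, here is a plain-Python sketch of what the size feature records for one batch, following the description above: every example list is padded or truncated to list_size, and each size is measured before that happens. build_features, the "examples" key, and the feature name "example_list_size" are hypothetical; in the pipeline the size feature is a tf.int32 Tensor built inside the input function, not a Python list.

```python
def build_features(example_lists, list_size, size_feature_name=None, pad_value=None):
    """Pads or truncates each example list to exactly list_size items.

    If size_feature_name is given, the returned dict also maps that name to
    the original length of each list, measured before padding/truncation.
    If it is None, no size entry is added.
    """
    # Record sizes first, so they reflect the raw, unmodified lists.
    sizes = [len(examples) for examples in example_lists]
    padded = []
    for examples in example_lists:
        kept = list(examples[:list_size])              # truncate long lists
        kept += [pad_value] * (list_size - len(kept))  # pad short lists
        padded.append(kept)
    features = {"examples": padded}
    if size_feature_name is not None:
        features[size_feature_name] = sizes
    return features


if __name__ == "__main__":
    batch = [["d1", "d2"], ["d3", "d4", "d5", "d6"]]
    features = build_features(batch, list_size=3, size_feature_name="example_list_size")
    print(features["examples"])           # [['d1', 'd2', None], ['d3', 'd4', 'd5']]
    print(features["example_list_size"])  # [2, 4]
```

Note that the second size (4) exceeds list_size (3), because it is taken before truncation.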

Methods

train_and_eval


Launches train and evaluation jobs locally.