tfr.keras.pipeline.MultiLabelDatasetBuilder

Builds datasets for multi-task training.

Inherits From: BaseDatasetBuilder, AbstractDatasetBuilder

This builder supports a single dataset whose examples carry multiple labels, organized as a dict keyed by task name. Multiple datasets are not handled yet; the builder may be extended if that use case arises.

Example usage:

```python
import tensorflow as tf
import tensorflow_ranking as tfr

# Label value for padded examples; TF-Ranking treats negative labels as invalid.
_PADDING_LABEL = -1.

context_feature_spec = {}
example_feature_spec = {
    "example_feature_1": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
# Both tasks read the same "utility" label feature here.
label_spec_tuple = ("utility",
                    tf.io.FixedLenFeature(
                        shape=(1,),
                        dtype=tf.float32,
                        default_value=_PADDING_LABEL))
label_spec = {"task1": label_spec_tuple, "task2": label_spec_tuple}
weight_spec = ("weight",
               tf.io.FixedLenFeature(
                   shape=(1,), dtype=tf.float32, default_value=1.))
dataset_hparams = tfr.keras.pipeline.DatasetHparams(
    train_input_pattern="train.dat",
    valid_input_pattern="valid.dat",
    train_batch_size=128,
    valid_batch_size=128)
dataset_builder = tfr.keras.pipeline.MultiLabelDatasetBuilder(
    context_feature_spec,
    example_feature_spec,
    mask_feature_name,
    label_spec,
    dataset_hparams,
    sample_weight_spec=weight_spec)
```
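The label_spec in this example routes the same "utility" feature to both tasks. How a parsed feature dict might be split into model features and a per-task label dict can be sketched in plain Python. This is an illustration under assumptions, not TF-Ranking's implementation: `split_features_and_labels` is a hypothetical helper, nested Python lists stand in for parsed tensors, and `None` stands in for each `tf.io.FixedLenFeature`.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical stand-in: task name -> (label feature name, feature spec).
# In TF-Ranking the spec would be a tf.io.FixedLenFeature; None is a placeholder.
LabelSpec = Dict[str, Tuple[str, Any]]
Tensorish = List[List[float]]  # stand-in for a parsed [list_size, 1] tensor


def split_features_and_labels(
    parsed: Dict[str, Tensorish], label_spec: LabelSpec
) -> Tuple[Dict[str, Tensorish], Dict[str, Tensorish]]:
    """Splits one parsed feature dict into (features, {task_name: label})."""
    # Look up every task's label first, so tasks that share a label
    # feature (like "task1" and "task2" here) both receive it.
    labels = {task: parsed[name] for task, (name, _) in label_spec.items()}
    label_names = {name for name, _ in label_spec.values()}
    # Everything that is not a label feature stays in the feature dict.
    features = {k: v for k, v in parsed.items() if k not in label_names}
    return features, labels


label_spec = {"task1": ("utility", None), "task2": ("utility", None)}
parsed = {
    "example_feature_1": [[0.3], [0.7]],  # one query with 2 documents
    "utility": [[1.0], [0.0]],
}
features, labels = split_features_and_labels(parsed, label_spec)
print(sorted(features))                    # ['example_feature_1']
print(labels["task1"] == labels["task2"])  # True
```

Reading every task's label before dropping label features is a design choice of this sketch: it keeps tasks that share a label name from colliding.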

Args

context_feature_spec: Maps context (aka query) feature names to feature specs.
example_feature_spec: Maps example (aka document) feature names to feature specs.
mask_feature_name: If set, the feature dictionary is populated with a feature of this name whose value is a tf.bool Tensor of shape [batch_size, list_size] indicating, for each position, whether it holds a real example (True) or padding (False).
label_spec: A dict that maps task names to label specs. Each label spec is a (label_name, tf.io.FixedLenFeature) tuple.
hparams: Dataset hyperparameters such as input patterns and batch sizes (a DatasetHparams instance in the example above).
sample_weight_spec: Optional (feature_name, tf.io.FixedLenFeature) tuple for the per-example weight feature.
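To make the mask semantics concrete, here is a small pure-Python sketch of a batch padded to a common list size. It assumes each query's number of real examples is known; `build_list_mask` and `pad_labels` are hypothetical helpers, nested lists of bools stand in for the tf.bool Tensor, and -1.0 is used as the padding label, following TF-Ranking's convention that negative labels are invalid.

```python
from typing import List


def build_list_mask(list_lengths: List[int], list_size: int) -> List[List[bool]]:
    """Returns a [batch_size, list_size] mask: True for real examples, False for padding."""
    return [[i < n for i in range(list_size)] for n in list_lengths]


def pad_labels(labels: List[List[float]], list_size: int,
               padding_label: float = -1.0) -> List[List[float]]:
    """Pads (or truncates) each query's labels to list_size with padding_label."""
    return [(row + [padding_label] * list_size)[:list_size] for row in labels]


# Two queries with 3 and 1 real documents, padded to list_size=4.
labels = [[2.0, 0.0, 1.0], [1.0]]
mask = build_list_mask([len(row) for row in labels], list_size=4)
padded = pad_labels(labels, list_size=4)
print(mask)    # [[True, True, True, False], [True, False, False, False]]
print(padded)  # [[2.0, 0.0, 1.0, -1.0], [1.0, -1.0, -1.0, -1.0]]
```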

Methods

build_signatures

Returns the signatures used to export the model as a SavedModel. See AbstractDatasetBuilder.

build_train_dataset

Returns the training dataset as a tf.data.Dataset. See AbstractDatasetBuilder.

build_valid_dataset

Returns the validation dataset as a tf.data.Dataset. See AbstractDatasetBuilder.
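These three methods form the builder's contract: two dataset factories and a signature factory used at export time. The sketch below mirrors that shape with Python's abc module so it runs without TensorFlow. `DatasetBuilderContract` and `InMemoryDatasetBuilder` are hypothetical; they return plain lists of (features, labels) pairs rather than tf.data.Dataset objects and are not how TF-Ranking implements AbstractDatasetBuilder.

```python
import abc
from typing import Any, Dict, List, Tuple

Batch = Tuple[Dict[str, Any], Dict[str, Any]]  # (features, {task_name: labels})


class DatasetBuilderContract(abc.ABC):
    """Toy mirror of the three methods above (not the TF-Ranking class)."""

    @abc.abstractmethod
    def build_train_dataset(self) -> List[Batch]:
        """Returns the training data."""

    @abc.abstractmethod
    def build_valid_dataset(self) -> List[Batch]:
        """Returns the validation data."""

    @abc.abstractmethod
    def build_signatures(self, *args, **kwargs) -> Any:
        """Returns what an exporter would need to serve the model."""


class InMemoryDatasetBuilder(DatasetBuilderContract):
    """Hypothetical builder backed by Python lists instead of input files."""

    def __init__(self, train: List[Batch], valid: List[Batch]):
        self._train = train
        self._valid = valid

    def build_train_dataset(self) -> List[Batch]:
        return list(self._train)

    def build_valid_dataset(self) -> List[Batch]:
        return list(self._valid)

    def build_signatures(self, *args, **kwargs) -> Any:
        # A real builder returns SavedModel serving signatures; this stub
        # only reports the feature names a serving input would need.
        return {"feature_names": sorted(self._train[0][0]) if self._train else []}


batch = ({"example_feature_1": [[0.3], [0.7]]},
         {"task1": [[1.0], [0.0]], "task2": [[1.0], [0.0]]})
builder = InMemoryDatasetBuilder(train=[batch], valid=[batch])
for features, labels in builder.build_train_dataset():
    print(sorted(labels))  # ['task1', 'task2']
```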