tfl.configs.RegularizerConfig

Regularizer configuration for TFL canned estimators.

Regularizers can be applied either to specific features or globally to all features or lattices.

  • Calibrator regularizers:

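    These regularizers are applied to PWL calibration layers.

    • 'calib_laplacian': Creates an instance of tfl.pwl_calibration_layer.LaplacianRegularizer. A calibrator Laplacian regularizer penalizes changes in the calibrator output, resulting in a flatter calibration function.
    • 'calib_hessian': Creates an instance of tfl.pwl_calibration_layer.HessianRegularizer. A calibrator Hessian regularizer penalizes changes in the slope, resulting in a more linear calibration function.
    • 'calib_wrinkle': Creates an instance of tfl.pwl_calibration_layer.WrinkleRegularizer. A calibrator wrinkle regularizer penalizes changes in the second derivative, resulting in a smoother calibration function with fewer changes in curvature.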

  • Lattice regularizers:

    These regularizers are applied to lattice layers.

    • 'laplacian': Creates an instance of tfl.lattice_layer.LaplacianRegularizer. Laplacian regularizers penalize the difference between adjacent vertices in a multi-cell lattice, resulting in a flatter lattice function.
    • 'torsion': Creates an instance of tfl.lattice_layer.TorsionRegularizer. Torsion regularizers penalize how much the lattice function twists from side to side, that is, the non-linear interactions within each 2 x 2 cell. Using this regularization results in a more linear lattice function.
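The lattice penalties described above can be illustrated in plain Python for a 2-D lattice. This is a simplified sketch, not the TFL implementation: it assumes the lattice vertex values are stored as a list of rows, that the torsion of a 2 x 2 cell is measured as w00 + w11 - w01 - w10, and that each penalty term is weighted as l1 * sum(|term|) + l2 * sum(term^2). The helper names laplacian_penalty and torsion_penalty are hypothetical.

```python
# Simplified sketch (not TFL's implementation) of Laplacian and torsion
# penalties on a 2-D lattice stored as a list of rows of vertex values.


def laplacian_penalty(weights, l1=0.0, l2=0.0):
    """Penalize differences between adjacent vertices along each dimension."""
    rows, cols = len(weights), len(weights[0])
    diffs = []
    # Differences along dimension 0 (between consecutive rows).
    for i in range(rows - 1):
        for j in range(cols):
            diffs.append(weights[i + 1][j] - weights[i][j])
    # Differences along dimension 1 (between consecutive columns).
    for i in range(rows):
        for j in range(cols - 1):
            diffs.append(weights[i][j + 1] - weights[i][j])
    return l1 * sum(abs(d) for d in diffs) + l2 * sum(d * d for d in diffs)


def torsion_penalty(weights, l1=0.0, l2=0.0):
    """Penalize the twist w00 + w11 - w01 - w10 of every 2 x 2 cell."""
    rows, cols = len(weights), len(weights[0])
    twists = []
    for i in range(rows - 1):
        for j in range(cols - 1):
            twists.append(weights[i][j] + weights[i + 1][j + 1]
                          - weights[i][j + 1] - weights[i + 1][j])
    return l1 * sum(abs(t) for t in twists) + l2 * sum(t * t for t in twists)


# Vertex values of f(x0, x1) = x0 + 2 * x1: an additive (untwisted) lattice.
additive = [[0.0, 2.0], [1.0, 3.0]]
# An XOR-like lattice: maximally twisted within its single cell.
twisted = [[0.0, 1.0], [1.0, 0.0]]

print(torsion_penalty(additive, l2=1.0))   # 0.0
print(torsion_penalty(twisted, l2=1.0))    # 4.0
print(laplacian_penalty(twisted, l2=1.0))  # 4.0
```

An additive lattice, whose cell satisfies w00 + w11 = w01 + w10, incurs no torsion penalty, which matches the description of torsion as a penalty on non-linear interactions. The XOR-like lattice is penalized by both regularizers.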

Examples:

import tensorflow_lattice as tfl

model_config = tfl.configs.CalibratedLatticeConfig(
    feature_configs=[
        tfl.configs.FeatureConfig(
            name='age',
            lattice_size=3,
            # Per-feature regularization.
            regularizer_configs=[
                tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),
            ],
        ),
        tfl.configs.FeatureConfig(
            name='thal',
            # Partial monotonicity:
            # output(normal) <= output(fixed)
            # output(normal) <= output(reversible)
            monotonicity=[('normal', 'fixed'), ('normal', 'reversible')],
        ),
    ],
    # Global regularizers.
    regularizer_configs=[
        # Torsion regularizer applied to the lattice to make it more linear.
        tfl.configs.RegularizerConfig(name='torsion', l2=1e-4),
        # Globally defined calibration regularizer is applied to all features.
        tfl.configs.RegularizerConfig(name='calib_hessian', l2=1e-4),
    ])

# create_input_fn and feature_columns are not defined in this snippet; they
# stand for the user's own input pipeline and feature column definitions.
feature_analysis_input_fn = create_input_fn(num_epochs=1, ...)
train_input_fn = create_input_fn(num_epochs=100, ...)
estimator = tfl.estimators.CannedClassifier(
    feature_columns=feature_columns,
    model_config=model_config,
    feature_analysis_input_fn=feature_analysis_input_fn)
estimator.train(input_fn=train_input_fn)

Args

  • name: The name of the regularizer.
  • l1: L1 regularization amount.
  • l2: L2 regularization amount.
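To see how the l1 and l2 amounts might weight a penalty, the sketch below applies them to the slope changes of a piecewise-linear calibrator, in the spirit of 'calib_hessian'. It assumes uniformly spaced input keypoints (so each slope reduces to the difference between consecutive outputs) and the usual L1/L2 combination; slope_changes and weighted_penalty are hypothetical helpers, not TFL functions.

```python
# Simplified sketch (not TFL's implementation): weighting a calibrator
# penalty with l1 and l2 amounts. The penalized terms are the changes in
# slope between consecutive segments of a piecewise-linear calibrator.
# Assumes uniformly spaced input keypoints, so slopes are output deltas.


def slope_changes(outputs):
    """Second differences of calibrator outputs at consecutive keypoints."""
    deltas = [b - a for a, b in zip(outputs, outputs[1:])]
    return [d2 - d1 for d1, d2 in zip(deltas, deltas[1:])]


def weighted_penalty(values, l1=0.0, l2=0.0):
    """Return l1 * sum(|v|) + l2 * sum(v^2) over the penalized terms."""
    return l1 * sum(abs(v) for v in values) + l2 * sum(v * v for v in values)


linear_calibrator = [0.0, 1.0, 2.0, 3.0]  # constant slope of 1
kinked_calibrator = [0.0, 1.0, 1.0, 3.0]  # slope goes 1 -> 0 -> 2

print(weighted_penalty(slope_changes(linear_calibrator), l2=1e-4))
print(weighted_penalty(slope_changes(kinked_calibrator), l1=0.5, l2=1e-4))
```

A calibrator with constant slope has no slope changes and therefore no penalty, while the kinked calibrator is penalized by both terms. The L1 term grows linearly with each slope change, and the L2 term punishes large changes disproportionately.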

Methods

deserialize_nested_configs

Returns a deserialized configuration dictionary.

from_config

Creates a configuration object from a configuration dictionary, such as one returned by get_config.

get_config

Returns a configuration dictionary.