tfl.configs.AggregateFunctionConfig

Config for aggregate function learning model.

An aggregate function learning model applies piecewise-linear and categorical calibration to the ragged input features, followed by an aggregation layer that aggregates the calibrated inputs. Lastly, a lattice model and an optional output piecewise-linear calibration are applied.
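To make the first stage concrete, here is a minimal plain-Python sketch of what a piecewise-linear calibrator computes on a ragged input. It assumes the output keypoint values have already been learned, and the names pwl_calibrate and calibrate_ragged are hypothetical helpers, not part of the TensorFlow Lattice API (which implements calibration as Keras layers). Categorical calibration, essentially a per-category lookup, is not shown.

```python
from bisect import bisect_right
from typing import List, Sequence


def pwl_calibrate(x: float, input_keypoints: Sequence[float],
                  output_keypoints: Sequence[float]) -> float:
    """Piecewise-linear map through (input, output) keypoint pairs.

    input_keypoints must be strictly increasing; inputs outside their range
    are clamped to the first or last output value.
    """
    if len(input_keypoints) < 2 or len(input_keypoints) != len(output_keypoints):
        raise ValueError("need at least two matching input/output keypoints")
    if x <= input_keypoints[0]:
        return float(output_keypoints[0])
    if x >= input_keypoints[-1]:
        return float(output_keypoints[-1])
    # Segment index i such that input_keypoints[i] <= x < input_keypoints[i + 1].
    i = bisect_right(input_keypoints, x) - 1
    x0, x1 = input_keypoints[i], input_keypoints[i + 1]
    y0, y1 = output_keypoints[i], output_keypoints[i + 1]
    return y0 + (x - x0) / (x1 - x0) * (y1 - y0)


def calibrate_ragged(rows: Sequence[Sequence[float]], input_keypoints: Sequence[float],
                     output_keypoints: Sequence[float]) -> List[List[float]]:
    """Calibrates every element of a ragged batch with the same calibrator."""
    return [[pwl_calibrate(v, input_keypoints, output_keypoints) for v in row]
            for row in rows]
```

For example, with input keypoints [0.0, 1.0, 2.0] and output keypoints [0.0, 0.5, 1.0], calibrate_ragged([[0.5, 1.5], [2.0]], ...) returns [[0.25, 0.75], [1.0]]; rows of different lengths stay ragged after calibration.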

Example:

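import tensorflow_lattice as tfl

# Build the config; each FeatureConfig describes one (ragged) input feature.
model_config = tfl.configs.AggregateFunctionConfig(
    feature_configs=[...],
)
# Construct the premade Keras model from the config, then train as usual.
model = tfl.premade.AggregateFunction(model_config)
model.compile(...)
model.fit(...)
model.evaluate(...)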

feature_configs A list of tfl.configs.FeatureConfig instances that specify configurations for each feature.
regularizer_configs A list of tfl.configs.RegularizerConfig instances that apply global regularization.
middle_dimension The number of calibrated lattices that are applied to each block. The outputs of these lattices are then averaged over the blocks, and the middle_dimension resulting numbers are then passed into the "middle" calibrated lattice. This middle lattice therefore has input dimension equal to middle_dimension.
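The averaging step can be sketched in plain Python. This is an illustration only: it assumes each block has already been passed through middle_dimension per-block lattices, and the helper name average_over_blocks is hypothetical.

```python
from typing import List, Sequence


def average_over_blocks(block_outputs: Sequence[Sequence[float]]) -> List[float]:
    """Averages per-block lattice outputs into one middle_dimension vector.

    block_outputs[b][j] is the output of the j-th per-block lattice for
    block b, so every row must have length middle_dimension.
    """
    if not block_outputs:
        raise ValueError("at least one block is required")
    middle_dimension = len(block_outputs[0])
    if any(len(row) != middle_dimension for row in block_outputs):
        raise ValueError("every block must produce middle_dimension outputs")
    n = len(block_outputs)
    return [sum(row[j] for row in block_outputs) / n for j in range(middle_dimension)]
```

With three blocks and middle_dimension = 2, the input [[0.2, 0.8], [0.4, 0.6], [0.6, 1.0]] averages to approximately [0.4, 0.8]. Because the result always has length middle_dimension, however many blocks an example contains, the middle lattice can have a fixed input dimension.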
middle_lattice_size Size of each dimension of the middle lattice.
middle_calibration Whether a piecewise-linear calibration should be applied to the inputs of the middle lattice.
middle_calibration_num_keypoints Number of keypoints to use for the middle piecewise-linear calibration.
middle_calibration_input_keypoints_type One of "fixed" or "learned_interior". If "learned_interior", keypoints are initialized to the values in pwl_calibration_input_keypoints but are then allowed to vary during training, except for the first and last keypoint locations, which remain fixed.
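One way to let interior keypoints move while keeping the endpoints fixed and the keypoints sorted is to learn one unconstrained logit per segment and map the logits to positive segment widths. The sketch below shows that parameterization in plain Python; it is an assumption chosen for illustration, not necessarily how TensorFlow Lattice implements "learned_interior", and both helper names are hypothetical. "fixed" simply keeps the initial keypoints unchanged.

```python
import math
from typing import List, Sequence


def logits_from_keypoints(keypoints: Sequence[float]) -> List[float]:
    """Initial logits that reproduce the given strictly increasing keypoints."""
    return [math.log(b - a) for a, b in zip(keypoints, keypoints[1:])]


def keypoints_from_logits(first: float, last: float,
                          logits: Sequence[float]) -> List[float]:
    """Sorted keypoints whose endpoints stay pinned at `first` and `last`.

    A softmax over one unconstrained logit per segment gives positive segment
    widths summing to (last - first), so interior keypoints can move during
    training without crossing each other or leaving the range.
    """
    shift = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(v - shift) for v in logits]
    total = sum(weights)
    keypoints = [float(first)]
    for w in weights[:-1]:
        keypoints.append(keypoints[-1] + (last - first) * w / total)
    keypoints.append(float(last))  # pin the last keypoint exactly
    return keypoints
```

Initializing the logits with logits_from_keypoints(initial_keypoints) reproduces the initial keypoints; any later change to the logits moves only the interior points.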
middle_monotonicity Specifies if the middle calibrators should be monotonic, using 'increasing' or 1 to indicate increasing monotonicity, 'decreasing' or -1 to indicate decreasing monotonicity, and 'none' or 0 to indicate no monotonicity constraints.
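The accepted spellings can be normalized to a single integer, and because a piecewise-linear calibrator is monotonic exactly when its output keypoints are, the constraint can be checked on those keypoints. A minimal sketch with hypothetical helper names, not TensorFlow Lattice functions:

```python
from typing import Sequence, Union

# Accepted spellings mapped to a canonical direction.
_MONOTONICITY = {"increasing": 1, "decreasing": -1, "none": 0, 1: 1, -1: -1, 0: 0}


def canonical_monotonicity(value: Union[str, int]) -> int:
    """Maps 'increasing'/1, 'decreasing'/-1 and 'none'/0 to 1, -1 or 0."""
    key = value.lower() if isinstance(value, str) else value
    try:
        return _MONOTONICITY[key]
    except KeyError:
        raise ValueError(f"unsupported monotonicity: {value!r}") from None


def satisfies_monotonicity(output_keypoints: Sequence[float],
                           monotonicity: Union[str, int]) -> bool:
    """Checks a calibrator's output keypoints against the constraint."""
    direction = canonical_monotonicity(monotonicity)
    diffs = [b - a for a, b in zip(output_keypoints, output_keypoints[1:])]
    if direction == 1:
        return all(d >= 0 for d in diffs)
    if direction == -1:
        return all(d <= 0 for d in diffs)
    return True  # no constraint
```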
middle_lattice_interpolation One of 'hypercube' or 'simplex'. For a d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas 'simplex' uses d+1 parameters and thus scales better. For details see tfl.lattice_lib.evaluate_with_simplex_interpolation and tfl.lattice_lib.evaluate_with_hypercube_interpolation.
aggregation_lattice_interpolation One of 'hypercube' or 'simplex'. For a d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas 'simplex' uses d+1 parameters and thus scales better. For details see tfl.lattice_lib.evaluate_with_simplex_interpolation and tfl.lattice_lib.evaluate_with_hypercube_interpolation.
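The difference between the two schemes shows up on a single lattice cell. The plain-Python sketch below assumes a lattice of size 2 in every dimension (one unit hypercube) with inputs already scaled to [0, 1]; params maps each cube vertex to a lattice parameter. It is an illustration, not a reproduction of tfl.lattice_lib.evaluate_with_hypercube_interpolation or tfl.lattice_lib.evaluate_with_simplex_interpolation, and the simplex variant uses the common sort-based construction.

```python
from itertools import product
from typing import Dict, Sequence, Tuple

Vertex = Tuple[int, ...]


def hypercube_interpolate(x: Sequence[float], params: Dict[Vertex, float]) -> float:
    """Multilinear interpolation: every one of the 2^d vertices gets a weight."""
    total = 0.0
    for vertex in product((0, 1), repeat=len(x)):
        weight = 1.0
        for xi, vi in zip(x, vertex):
            weight *= xi if vi else 1.0 - xi
        total += weight * params[vertex]
    return total


def simplex_interpolate(x: Sequence[float], params: Dict[Vertex, float]) -> float:
    """Sort-based simplex interpolation: only d + 1 vertices are touched."""
    d = len(x)
    order = sorted(range(d), key=lambda i: x[i], reverse=True)
    vertex = [0] * d
    total = (1.0 - x[order[0]]) * params[tuple(vertex)]
    for k, i in enumerate(order):
        vertex[i] = 1  # walk from the origin toward (1, ..., 1)
        nxt = x[order[k + 1]] if k + 1 < d else 0.0
        total += (x[i] - nxt) * params[tuple(vertex)]
    return total
```

Both schemes agree when the lattice parameters describe a linear function (for params {(0,0): 0, (1,0): 1, (0,1): 2, (1,1): 3}, both return 1.25 at x = (0.25, 0.5)) but differ otherwise: with {(0,0): 0, (1,0): 1, (0,1): 1, (1,1): 0} at x = (0.5, 0.5), hypercube gives 0.5 and simplex gives 0.0. For d = 10, hypercube weighs 1024 vertices per evaluation, while simplex weighs 11.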
output_min Lower bound constraint on the output of the model.
output_max Upper bound constraint on the output of the model.
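A lattice output is a weighted sum of lattice parameters with nonnegative weights that sum to one, so keeping every parameter inside [output_min, output_max] is enough to keep every output inside the bounds. The sketch below shows that reasoning with a simple clipping projection; the helper names are hypothetical, and we do not claim this is exactly how TensorFlow Lattice enforces these bounds (with output calibration enabled, the bound would apply to the calibrator instead).

```python
from typing import List, Optional, Sequence


def project_to_bounds(params: Sequence[float], output_min: Optional[float] = None,
                      output_max: Optional[float] = None) -> List[float]:
    """Clips lattice parameters into [output_min, output_max]; None means unbounded."""
    projected = []
    for p in params:
        if output_min is not None:
            p = max(p, output_min)
        if output_max is not None:
            p = min(p, output_max)
        projected.append(p)
    return projected


def interpolate(weights: Sequence[float], params: Sequence[float]) -> float:
    """Lattice-style output: a weighted sum with nonnegative weights summing to 1."""
    if any(w < 0 for w in weights) or abs(sum(weights) - 1.0) > 1e-9:
        raise ValueError("weights must be nonnegative and sum to 1")
    return sum(w * p for w, p in zip(weights, params))
```

For instance, project_to_bounds([-0.5, 0.2, 1.7, 0.9], 0.0, 1.0) yields [0.0, 0.2, 1.0, 0.9], and any valid weighting of those values stays within [0, 1].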
output_calibration Whether a piecewise-linear calibration should be applied to the output of the lattice.
output_calibration_num_keypoints Number of keypoints to use for the output piecewise-linear calibration.
output_initialization The initial values to set up for the output of the model. When output calibration is used, these values initialize the output keypoints of the output piecewise-linear calibration. Otherwise, the lattice parameters are set up to form a linear function in the range of output_initialization. It can be one of:

  • String 'uniform': Output is initialized uniformly in the label range.
  • A list of numbers: To be used for initialization of the output lattice or output calibrator.
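The two forms can be sketched as follows, assuming "uniform in the label range" means evenly spaced values from the smallest to the largest label. The function resolve_output_initialization is hypothetical, passes an explicit list through unchanged, and does not model any validation TensorFlow Lattice may perform on lists.

```python
from typing import List, Sequence, Union


def resolve_output_initialization(output_initialization: Union[str, Sequence[float]],
                                  num_values: int, label_min: float,
                                  label_max: float) -> List[float]:
    """Turns an output_initialization setting into concrete starting values."""
    if isinstance(output_initialization, str):
        if output_initialization != "uniform":
            raise ValueError(f"unsupported initialization: {output_initialization!r}")
        if num_values < 2:
            raise ValueError("'uniform' needs at least two values")
        # Evenly spaced values from label_min to label_max, inclusive.
        step = (label_max - label_min) / (num_values - 1)
        return [label_min + i * step for i in range(num_values)]
    # An explicit list is used as given.
    return [float(v) for v in output_initialization]
```

For example, resolve_output_initialization("uniform", 5, 0.0, 1.0) gives [0.0, 0.25, 0.5, 0.75, 1.0], which could serve as initial output keypoints for a five-keypoint output calibrator.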
output_calibration_input_keypoints_type One of "fixed" or "learned_interior". If "learned_interior", keypoints are initialized to the values in pwl_calibration_input_keypoints but are then allowed to vary during training, except for the first and last keypoint locations, which remain fixed.

Methods

deserialize_nested_configs

Returns a deserialized configuration dictionary.

feature_config_by_name

Returns existing or default FeatureConfig with the given name.

from_config

get_config

Returns a configuration dictionary.

regularizer_config_by_name

Returns existing or default RegularizerConfig with the given name.
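The "existing or default" behavior of feature_config_by_name and regularizer_config_by_name can be illustrated with hypothetical stand-in classes. FeatureSpec and ModelSpec are not TensorFlow Lattice classes, and this sketch returns a fresh default without storing it on the config; the real methods may differ in that detail.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class FeatureSpec:
    """Hypothetical stand-in for a per-feature config (not tfl.configs.FeatureConfig)."""
    name: str
    lattice_size: int = 2


@dataclass
class ModelSpec:
    """Hypothetical stand-in showing the 'existing or default' lookup pattern."""
    feature_configs: List[FeatureSpec] = field(default_factory=list)

    def feature_config_by_name(self, name: str) -> FeatureSpec:
        for config in self.feature_configs:
            if config.name == name:
                return config  # existing entry wins
        # No explicit entry for this name: fall back to a default config.
        return FeatureSpec(name=name)
```

With spec = ModelSpec([FeatureSpec("age", lattice_size=3)]), looking up "age" returns the configured entry (lattice_size 3), while looking up "income" returns a default with lattice_size 2.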