tfl.configs.AggregateFunctionConfig

Config for aggregate function learning model.

An aggregate function learning model applies piecewise-linear and categorical calibration to the ragged input features, followed by an aggregation layer that aggregates the calibrated inputs. Lastly, a lattice model and an optional output piecewise-linear calibration are applied.
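To make the data flow concrete, here is a minimal pure-Python sketch of the first two stages for a single ragged feature: each element of a variable-length example goes through a piecewise-linear calibrator, and the calibrated values are averaged. It is a simplification under stated assumptions: the helper names are hypothetical, the keypoints are fixed rather than learned, a plain mean stands in for the aggregation, and the per-element lattices and the final lattice are omitted. It is not TensorFlow Lattice code.

```python
from typing import List, Sequence


def pwl_calibrate(x: float, input_keypoints: Sequence[float],
                  output_keypoints: Sequence[float]) -> float:
    """Piecewise-linear interpolation between keypoints, clamped at both ends."""
    if x <= input_keypoints[0]:
        return output_keypoints[0]
    if x >= input_keypoints[-1]:
        return output_keypoints[-1]
    for i in range(1, len(input_keypoints)):
        if x <= input_keypoints[i]:
            x0, x1 = input_keypoints[i - 1], input_keypoints[i]
            y0, y1 = output_keypoints[i - 1], output_keypoints[i]
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
    raise AssertionError("unreachable: x lies inside the keypoint range")


def calibrate_and_average(elements: Sequence[float],
                          input_keypoints: Sequence[float],
                          output_keypoints: Sequence[float]) -> float:
    """Calibrate every element of one ragged example, then average them."""
    calibrated = [pwl_calibrate(v, input_keypoints, output_keypoints)
                  for v in elements]
    return sum(calibrated) / len(calibrated)


# Two examples with different numbers of elements (a ragged batch).
ragged_batch: List[List[float]] = [[0.5, 2.0, 3.0], [1.0]]
input_kp = [0.0, 1.0, 4.0]
output_kp = [0.0, 0.5, 1.0]
results = [calibrate_and_average(e, input_kp, output_kp) for e in ragged_batch]
print(results)
```

Because each example is reduced to a fixed-size value regardless of how many elements it has, the later lattice stages can operate on ordinary dense inputs.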

Example:

import tensorflow_lattice as tfl

model_config = tfl.configs.AggregateFunctionConfig(
    feature_configs=[...],
)
model = tfl.premade.AggregateFunction(model_config)
model.compile(...)
model.fit(...)
model.evaluate(...)

Args

feature_configs A list of tfl.configs.FeatureConfig instances that specify configurations for each feature.
regularizer_configs A list of tfl.configs.RegularizerConfig instances that apply global regularization.
middle_dimension The number of calibrated lattices that are applied to each block. The outputs of these lattices are then averaged over the blocks, and the middle_dimension resulting numbers are then passed into the "middle" calibrated lattice. This middle lattice therefore has input dimension equal to middle_dimension.
middle_lattice_size Size of each dimension of the middle lattice.
middle_calibration Whether a piecewise-linear calibration should be applied to the inputs of the middle lattice.
middle_calibration_num_keypoints Number of keypoints to use for the middle piecewise-linear calibration.
middle_calibration_input_keypoints_type One of "fixed" or "learned_interior". If "learned_interior", keypoints are initialized to the values in pwl_calibration_input_keypoints but then allowed to vary during training, except for the first and last keypoint locations, which are fixed.
middle_monotonicity Specifies if the middle calibrators should be monotonic, using 'increasing' or 1 to indicate increasing monotonicity, 'decreasing' or -1 to indicate decreasing monotonicity, and 'none' or 0 to indicate no monotonicity constraints.
middle_lattice_interpolation One of 'hypercube' or 'simplex'. For a d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas 'simplex' uses d+1 parameters and thus scales better. For details see tfl.lattice_lib.evaluate_with_simplex_interpolation and tfl.lattice_lib.evaluate_with_hypercube_interpolation.
aggregation_lattice_interpolation One of 'hypercube' or 'simplex'. For a d-dimensional lattice, 'hypercube' interpolates 2^d parameters, whereas 'simplex' uses d+1 parameters and thus scales better. For details see tfl.lattice_lib.evaluate_with_simplex_interpolation and tfl.lattice_lib.evaluate_with_hypercube_interpolation.
output_min Lower bound constraint on the output of the model.
output_max Upper bound constraint on the output of the model.
output_calibration Whether a piecewise-linear calibration should be applied to the output of the lattice.
output_calibration_num_keypoints Number of keypoints to use for the output piecewise-linear calibration.
output_initialization The initial values to set up for the output of the model. When using output calibration, these values are used to initialize the output keypoints of the output piecewise-linear calibration. Otherwise, the lattice parameters are set up to form a linear function in the range of output_initialization. It can be one of:

  • String 'uniform': Output is initialized uniformly in the label range.
  • A list of numbers: To be used for initialization of the output lattice or output calibrator.
output_calibration_input_keypoints_type One of "fixed" or "learned_interior". If "learned_interior", keypoints are initialized to the values in pwl_calibration_input_keypoints but then allowed to vary during training, except for the first and last keypoint locations, which are fixed.
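The shape bookkeeping behind middle_dimension and middle_lattice_size can be sketched in plain Python. Assuming block_outputs[b][j] holds the output of the j-th calibrated lattice on block b, averaging over blocks leaves a vector of length middle_dimension, which is the input to the middle lattice; a lattice with middle_lattice_size vertices per dimension then holds middle_lattice_size ** middle_dimension parameters. The function names are hypothetical.

```python
from typing import List


def average_over_blocks(block_outputs: List[List[float]]) -> List[float]:
    """Average per-block lattice outputs into one middle-lattice input vector.

    block_outputs[b][j] is the output of the j-th calibrated lattice on block b,
    so each inner list has length middle_dimension.
    """
    num_blocks = len(block_outputs)
    middle_dimension = len(block_outputs[0])
    return [sum(block[j] for block in block_outputs) / num_blocks
            for j in range(middle_dimension)]


def lattice_num_params(lattice_size: int, dimension: int) -> int:
    """A lattice stores one parameter per grid vertex: lattice_size ** dimension."""
    return lattice_size ** dimension


# One example with three blocks and middle_dimension = 2.
block_outputs = [[0.2, 0.9], [0.4, 0.7], [0.6, 0.5]]
middle_inputs = average_over_blocks(block_outputs)
print(middle_inputs)  # approximately [0.4, 0.7]
print(lattice_num_params(lattice_size=3, dimension=2))  # 9
```

The parameter count grows exponentially in middle_dimension, which is one reason to keep it small.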
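The "learned_interior" keypoint type can be illustrated with one hypothetical parameterization: unconstrained parameters are turned into positive gaps with softplus and rescaled to span the fixed range, so the interior keypoints move with the parameters while the first and last keypoints stay put and the ordering is preserved. This parameterization is an assumption chosen for illustration; TensorFlow Lattice's own implementation may differ.

```python
import math
from typing import List, Sequence


def softplus(z: float) -> float:
    """Smooth positive function log(1 + e^z)."""
    return math.log1p(math.exp(z))


def keypoints_from_params(first: float, last: float,
                          raw_params: Sequence[float]) -> List[float]:
    """Hypothetical 'learned_interior' parameterization.

    One raw parameter per gap between consecutive keypoints (so
    num_keypoints - 1 of them). Gaps are made positive with softplus and
    rescaled to sum to (last - first), keeping the keypoints sorted.
    Only interior locations depend on raw_params; the endpoints are fixed.
    """
    gaps = [softplus(p) for p in raw_params]
    scale = (last - first) / sum(gaps)
    keypoints = [first]
    for gap in gaps[:-1]:
        keypoints.append(keypoints[-1] + gap * scale)
    keypoints.append(last)  # pin the last keypoint exactly
    return keypoints


# Equal raw parameters give evenly spaced keypoints on [0, 10].
initial = keypoints_from_params(0.0, 10.0, [0.0, 0.0, 0.0, 0.0])
# Changing the raw parameters (as training would) moves only interior points.
updated = keypoints_from_params(0.0, 10.0, [1.5, -0.3, 0.2, 0.0])
print(initial)
print(updated)
```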
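The values accepted by middle_monotonicity come in string and integer spellings. The sketch below, using hypothetical helper names rather than a TensorFlow Lattice function, maps them to one integer and checks calibrator outputs sampled at increasing inputs against the requested direction.

```python
from typing import Sequence, Union

# Accepted spellings mapped to a canonical integer.
_MONOTONICITY = {"increasing": 1, "decreasing": -1, "none": 0,
                 1: 1, -1: -1, 0: 0}


def canonical_monotonicity(spec: Union[str, int]) -> int:
    """Map 'increasing'/1, 'decreasing'/-1 or 'none'/0 to 1, -1 or 0."""
    try:
        return _MONOTONICITY[spec]
    except (KeyError, TypeError):
        raise ValueError(f"Unsupported monotonicity: {spec!r}") from None


def respects_monotonicity(outputs: Sequence[float],
                          spec: Union[str, int]) -> bool:
    """Check calibrator outputs, sampled at increasing inputs, against spec."""
    direction = canonical_monotonicity(spec)
    diffs = [b - a for a, b in zip(outputs, outputs[1:])]
    if direction == 1:
        return all(d >= 0 for d in diffs)
    if direction == -1:
        return all(d <= 0 for d in diffs)
    return True  # 'none' / 0: no constraint


print(respects_monotonicity([0.1, 0.4, 0.4, 0.9], "increasing"))  # True
print(respects_monotonicity([0.1, 0.4, 0.3], 1))                  # False
print(respects_monotonicity([0.9, 0.2, 0.5], "none"))             # True
```

Monotonicity here is non-strict: flat segments, such as the repeated 0.4 above, are allowed.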
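The 2^d versus d+1 parameter counts for 'hypercube' and 'simplex' interpolation (middle_lattice_interpolation, aggregation_lattice_interpolation) can be seen on a single unit cell [0, 1]^d, that is, a lattice of size 2 in every dimension. The sketch uses multilinear weights for 'hypercube' and the standard sort-based (Kuhn) simplex scheme for 'simplex'; it is an illustration, not a copy of tfl.lattice_lib, and a larger lattice would first locate the cell containing the input. Both schemes reproduce a linear function exactly, which the example uses as a check.

```python
import itertools
from typing import Dict, Sequence, Tuple

Vertex = Tuple[int, ...]


def hypercube_interpolate(x: Sequence[float], params: Dict[Vertex, float]) -> float:
    """Multilinear interpolation in the unit cell [0, 1]^d; reads 2^d params."""
    total = 0.0
    for vertex in itertools.product((0, 1), repeat=len(x)):
        weight = 1.0
        for xi, bit in zip(x, vertex):
            weight *= xi if bit else 1.0 - xi
        total += weight * params[vertex]
    return total


def simplex_interpolate(x: Sequence[float], params: Dict[Vertex, float]) -> float:
    """Sort-based simplex interpolation in [0, 1]^d; reads d + 1 params.

    Coordinates are visited from largest to smallest; each step flips one
    coordinate of the current vertex from 0 to 1, tracing the d + 1 corners
    of the simplex that contains x.
    """
    d = len(x)
    order = sorted(range(d), key=lambda i: x[i], reverse=True)
    vertex = [0] * d
    total = (1.0 - x[order[0]]) * params[tuple(vertex)]
    for k, i in enumerate(order):
        vertex[i] = 1
        next_x = x[order[k + 1]] if k + 1 < d else 0.0
        total += (x[i] - next_x) * params[tuple(vertex)]
    return total


# Vertex values of the linear function f(v) = v0 + 2*v1 + 3*v2 on a 3-D cell.
d = 3
params = {v: v[0] + 2 * v[1] + 3 * v[2]
          for v in itertools.product((0, 1), repeat=d)}
x = (0.2, 0.5, 0.9)
print(hypercube_interpolate(x, params), simplex_interpolate(x, params))
print(2 ** d, d + 1)  # 8 4: parameters read per evaluation
```

For nonlinear vertex values the two schemes generally give different results inside the cell, but they always agree at the vertices.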
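One way to read the 'uniform' option of output_initialization is as evenly spaced initial output values across the label range. The hypothetical helper below assumes output_min and output_max stand in for that range and passes a list of numbers through unchanged; how TensorFlow Lattice maps these values onto lattice parameters or calibrator keypoints is not shown.

```python
from typing import List, Optional, Sequence, Union


def resolve_output_initialization(
        spec: Union[str, Sequence[float]],
        output_min: Optional[float] = None,
        output_max: Optional[float] = None,
        num_values: int = 2) -> List[float]:
    """Hypothetical helper turning output_initialization into numbers.

    'uniform' becomes num_values evenly spaced values from output_min to
    output_max (standing in for the label range); a list of numbers is
    returned unchanged as floats.
    """
    if isinstance(spec, str):
        if spec != "uniform":
            raise ValueError(f"Unknown output_initialization: {spec!r}")
        if output_min is None or output_max is None:
            raise ValueError("'uniform' needs output_min and output_max here")
        if num_values < 2:
            raise ValueError("num_values must be at least 2")
        step = (output_max - output_min) / (num_values - 1)
        return [output_min + i * step for i in range(num_values)]
    return [float(v) for v in spec]


print(resolve_output_initialization("uniform", 0.0, 1.0, num_values=5))
# [0.0, 0.25, 0.5, 0.75, 1.0]
print(resolve_output_initialization([-2.0, 3.0]))  # [-2.0, 3.0]
```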

Methods

deserialize_nested_configs

Returns a deserialized configuration dictionary.

feature_config_by_name

Returns existing or default FeatureConfig with the given name.

from_config

get_config

Returns a configuration dictionary.

regularizer_config_by_name

Returns existing or default RegularizerConfig with the given name.