Module: tfl.premade_lib

Implementation of algorithms required for premade models.

Classes

class LayerOutputRange: Enum indicating a layer's output range, chosen according to the input expected by the next layer.
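
The idea behind LayerOutputRange can be illustrated with a minimal sketch: the bounds a layer produces depend on what consumes it. This assumes that a lattice of size k expects inputs in [0, k - 1] and that the output calibration layer expects inputs in [0, 1], as stated for build_output_calibration_layer. The enum OutputRange, its member names and the helper output_bounds are hypothetical stand-ins, not the library's API.

```python
import enum


class OutputRange(enum.Enum):
    """Hypothetical stand-in for tfl.premade_lib.LayerOutputRange (member names assumed)."""
    MODEL_OUTPUT = 1
    INPUT_TO_LATTICE = 2
    INPUT_TO_FINAL_CALIBRATION = 3


def output_bounds(kind, lattice_size=2, output_min=None, output_max=None):
    """Return the (min, max) bounds a layer should produce for its consumer (sketch)."""
    if kind is OutputRange.INPUT_TO_LATTICE:
        # A lattice of size k along a dimension expects inputs in [0, k - 1].
        return 0.0, float(lattice_size - 1)
    if kind is OutputRange.INPUT_TO_FINAL_CALIBRATION:
        # The output calibration layer takes inputs in [0, 1].
        return 0.0, 1.0
    # Model output: use whatever bounds the model config requests (None means unbounded).
    return output_min, output_max


print(output_bounds(OutputRange.INPUT_TO_LATTICE, lattice_size=3))
```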

Functions

build_aggregation_layer(...): Creates an aggregation layer using the given calibrated lattice models.

build_calibrated_lattice_ensemble_layer(...): Creates a calibration layer followed by a lattice ensemble layer.

build_calibration_layers(...): Creates calibration layers for submodels given as a list of lists of features.

build_input_layer(...): Creates a mapping from feature name to tf.keras.Input.

build_lattice_ensemble_layer(...): Creates an ensemble of tfl.layers.Lattice layers.

build_lattice_layer(...): Creates a tfl.layers.Lattice layer.

build_linear_combination_layer(...): Creates a tfl.layers.Linear layer initialized to be an average.

build_linear_layer(...): Creates a tfl.layers.Linear layer initialized to be an average.

build_multi_unit_calibration_layers(...): Creates a mapping from feature names to calibration outputs.

build_output_calibration_layer(...): Creates a monotonic output calibration layer with input range [0, 1].

build_rtl_layer(...): Creates a tfl.layers.RTL layer.

construct_prefitting_model_config(...): Constructs the model config for a prefitting model used for crystal extraction.

set_categorical_monotonicities(...): Maps categorical monotonicities to indices based on the specified vocabulary list.

set_crystals_lattice_ensemble(...): Extracts crystals from a prefitting model and finalizes model_config.

set_random_lattice_ensemble(...): Sets random lattice ensemble in the given model_config.

verify_config(...): Verifies that the model_config and feature_configs are fully specified.
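
For build_lattice_layer, a tfl.layers.Lattice layer is an interpolated look-up table: it stores a value at each vertex of a grid and, by default, interpolates multilinearly between them. The sketch below shows that interpolation for a 2-D lattice in plain Python, assuming each input lies in [0, size - 1] along its dimension. The function lattice_2d is hypothetical and illustrates the technique only; it is not how the library computes it.

```python
import math


def lattice_2d(params, x0, x1):
    """Bilinear interpolation over a 2-D lattice of shape len(params) x len(params[0]).

    params[i][j] is the value stored at vertex (i, j). Inputs are assumed to lie
    in [0, size - 1] along each dimension.
    """
    def cell(x, size):
        # Lower vertex of the cell containing x (clamped so x = size - 1 stays in
        # the last cell) and the fractional offset within that cell.
        i = min(int(math.floor(x)), size - 2)
        return i, x - i

    i, t0 = cell(x0, len(params))
    j, t1 = cell(x1, len(params[0]))
    # Weighted sum of the four surrounding vertices.
    return (params[i][j] * (1 - t0) * (1 - t1)
            + params[i + 1][j] * t0 * (1 - t1)
            + params[i][j + 1] * (1 - t0) * t1
            + params[i + 1][j + 1] * t0 * t1)


print(lattice_2d([[0.0, 1.0], [2.0, 3.0]], 0.5, 0.5))
```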
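
build_linear_layer and build_linear_combination_layer both describe a linear layer "initialized to be an average". The arithmetic is simple: with n inputs, setting every weight to 1/n makes the weighted sum equal the mean of the inputs. The sketch below assumes a zero bias; the helper names are hypothetical and do not mirror tfl.layers.Linear.

```python
def average_init_weights(num_inputs):
    """Weights that make a linear combination start out as a plain average."""
    return [1.0 / num_inputs] * num_inputs


def linear(weights, inputs, bias=0.0):
    """Weighted sum of inputs plus a bias (bias assumed zero at initialization)."""
    return sum(w * x for w, x in zip(weights, inputs)) + bias


w = average_init_weights(4)
print(linear(w, [1.0, 2.0, 3.0, 4.0]))  # the mean of the four inputs
```

Training is then free to move the weights away from the average; starting there simply means the initial model treats each submodel output equally.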
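
For build_output_calibration_layer, a monotonic output calibrator over inputs in [0, 1] can be pictured as a piecewise-linear function whose keypoint outputs never decrease. The sketch below assumes ascending keypoints in [0, 1] and clamps inputs outside the keypoint range; calibrate is a hypothetical helper, not the library's calibration layer.

```python
import bisect


def calibrate(x, keypoints, outputs):
    """Monotonic piecewise-linear calibration.

    keypoints: ascending input keypoints, assumed to lie in [0, 1].
    outputs:   values at those keypoints; must be non-decreasing so the
               resulting function is monotonic.
    """
    if any(b < a for a, b in zip(outputs, outputs[1:])):
        raise ValueError("outputs must be non-decreasing for a monotonic calibrator")
    # Clamp inputs outside the keypoint range to the end values.
    if x <= keypoints[0]:
        return outputs[0]
    if x >= keypoints[-1]:
        return outputs[-1]
    # Locate the segment [keypoints[k], keypoints[k + 1]] containing x.
    k = bisect.bisect_right(keypoints, x) - 1
    t = (x - keypoints[k]) / (keypoints[k + 1] - keypoints[k])
    return outputs[k] + t * (outputs[k + 1] - outputs[k])


print(calibrate(0.75, [0.0, 0.5, 1.0], [-1.0, 0.0, 4.0]))
```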
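
For set_categorical_monotonicities, one way to picture mapping categorical monotonicities to indices is this: a constraint such as "the output for category 'low' should not exceed that for 'high'" becomes a pair of positions in the feature's vocabulary list. The sketch below assumes constraints are given as (lower, higher) value pairs. The helper pairs_to_indices is hypothetical and ignores details the real function may handle, such as out-of-vocabulary buckets.

```python
def pairs_to_indices(pairs, vocabulary_list):
    """Map (lower, higher) category pairs to vocabulary indices (hypothetical helper)."""
    index = {value: i for i, value in enumerate(vocabulary_list)}
    result = []
    for lower, higher in pairs:
        if lower not in index or higher not in index:
            raise ValueError(f"pair {(lower, higher)!r} is not in the vocabulary list")
        result.append((index[lower], index[higher]))
    return result


print(pairs_to_indices([("low", "high"), ("low", "mid")], ["low", "mid", "high"]))
```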
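
For set_random_lattice_ensemble, a random lattice ensemble assigns a random subset of features to each lattice. The sketch below shows one plausible strategy: each of num_lattices lattices draws lattice_rank distinct features, using a seeded generator so the ensemble is reproducible. The function and parameter names are assumptions for illustration, and the library's actual sampling strategy may differ (for example, it may try to ensure every feature is used).

```python
import random


def random_lattice_ensemble(feature_names, num_lattices, lattice_rank, seed=0):
    """Assign lattice_rank distinct features to each of num_lattices lattices at random."""
    if lattice_rank > len(feature_names):
        raise ValueError("lattice_rank cannot exceed the number of features")
    rng = random.Random(seed)  # seeded so the same config yields the same ensemble
    # Features are sampled without replacement within a lattice,
    # but may repeat across lattices.
    return [rng.sample(feature_names, lattice_rank) for _ in range(num_lattices)]


print(random_lattice_ensemble(["a", "b", "c", "d"], num_lattices=3, lattice_rank=2, seed=42))
```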

Other Members

AGGREGATION_LAYER_NAME = 'tfl_aggregation'
CALIB_LAYER_NAME = 'tfl_calib'
CALIB_PASSTHROUGH_NAME = 'tfl_calib_passthrough'
INPUT_LAYER_NAME = 'tfl_input'
LATTICE_LAYER_NAME = 'tfl_lattice'
LINEAR_LAYER_NAME = 'tfl_linear'
OUTPUT_CALIB_LAYER_NAME = 'tfl_output_calib'
OUTPUT_LINEAR_COMBINATION_LAYER_NAME = 'tfl_output_linear_combination'
RTL_INPUT_NAME = 'tfl_rtl_input'
RTL_LAYER_NAME = 'tfl_rtl'