Has the ability to load and apply an ML model.
tfx_bsl.public.beam.run_inference.ModelHandler()
Methods
batch_elements_kwargs
batch_elements_kwargs() -> Mapping[str, Any]
Returns: kwargs suitable for beam.BatchElements.
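As a rough illustration, an override might return batch-size bounds. `min_batch_size` and `max_batch_size` are keyword arguments of `beam.BatchElements`; the class name `FixedBatchHandler` and the values are hypothetical. The class is plain Python so the sketch runs without Beam installed, whereas a real handler would subclass `ModelHandler`.

```python
from typing import Any, Mapping


class FixedBatchHandler:
    """Hypothetical stand-in; a real handler would subclass ModelHandler."""

    def __init__(self, min_batch_size: int = 8, max_batch_size: int = 64):
        self._batching_kwargs = {
            "min_batch_size": min_batch_size,
            "max_batch_size": max_batch_size,
        }

    def batch_elements_kwargs(self) -> Mapping[str, Any]:
        # These keys are meant to be forwarded to beam.BatchElements.
        return self._batching_kwargs


handler = FixedBatchHandler(max_batch_size=32)
batching = handler.batch_elements_kwargs()
```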
get_metrics_namespace
get_metrics_namespace() -> str
Returns: A namespace for metrics collected by the RunInference transform.
get_num_bytes
get_num_bytes(
batch: Sequence[ExampleT]
) -> int
Returns: The number of bytes of data for a batch.
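One plausible way to compute this for a batch of serialized examples is to sum the payload lengths. This is a sketch under that assumption (examples are `bytes`), not necessarily how the default implementation measures size.

```python
from typing import Sequence


def get_num_bytes(batch: Sequence[bytes]) -> int:
    # For serialized examples, the payload length is a natural size measure.
    return sum(len(example) for example in batch)


batch = [b"abc", b"de", b""]
total = get_num_bytes(batch)
```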
get_postprocess_fns
get_postprocess_fns() -> Iterable[Callable[[Any], Any]]
Gets all postprocessing functions to be run after inference. Functions are listed in the order in which they should be applied.
get_preprocess_fns
get_preprocess_fns() -> Iterable[Callable[[Any], Any]]
Gets all preprocessing functions to be run before batching/inference. Functions are listed in the order in which they should be applied.
get_resource_hints
get_resource_hints() -> dict
Returns: Resource hints for the transform.
load_model
load_model() -> ModelT
Loads and initializes a model for processing.
run_inference
run_inference(
batch: Sequence[ExampleT],
model: ModelT,
inference_args: Optional[Dict[str, Any]] = None
) -> Iterable[PredictionT]
Runs inferences on a batch of examples.
Args | |
---|---|
batch | A sequence of examples or features. |
model | The model used to make inferences. |
inference_args | Extra arguments for models whose inference call requires extra parameters. |
Returns | |
---|---|
An Iterable of Predictions. |
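The contract between `load_model` and `run_inference` can be sketched as follows. `DoublingModel` and `DoublingModelHandler` are hypothetical names, and the handler is a plain class that only mirrors the method names so the example runs without Beam; in a pipeline it would subclass `ModelHandler`.

```python
from typing import Any, Dict, Iterable, Optional, Sequence


class DoublingModel:
    """Hypothetical 'model' that doubles its input."""

    def predict(self, x: float) -> float:
        return 2 * x


class DoublingModelHandler:
    """Plain-Python stand-in mirroring the ModelHandler method names."""

    def load_model(self) -> DoublingModel:
        # A real handler would read weights from disk or a model store here.
        return DoublingModel()

    def run_inference(
        self,
        batch: Sequence[float],
        model: DoublingModel,
        inference_args: Optional[Dict[str, Any]] = None,
    ) -> Iterable[float]:
        # Produce one prediction per example, in input order.
        return [model.predict(x) for x in batch]


handler = DoublingModelHandler()
model = handler.load_model()
predictions = list(handler.run_inference([1.0, 2.5], model))
```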
set_environment_vars
set_environment_vars()
Sets environment variables using a dictionary provided via kwargs. Keys are the environment variable names, and values are the environment variable values. Child ModelHandler classes should set _env_vars via kwargs in __init__, or else call super().__init__().
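A minimal sketch of the described behavior, assuming the dictionary arrives under a kwarg named `env_vars` (an assumption, as are the class name and the variable `EXAMPLE_MODEL_CACHE`):

```python
import os


class EnvVarHandler:
    """Hypothetical stand-in; a real handler would subclass ModelHandler."""

    def __init__(self, **kwargs):
        # Assumed convention: the dictionary is passed as kwargs["env_vars"].
        self._env_vars = kwargs.get("env_vars", {})

    def set_environment_vars(self):
        # Copy each name/value pair into the process environment.
        for name, value in self._env_vars.items():
            os.environ[name] = value


handler = EnvVarHandler(env_vars={"EXAMPLE_MODEL_CACHE": "/tmp/cache"})
handler.set_environment_vars()
```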
share_model_across_processes
share_model_across_processes() -> bool
Returns a boolean representing whether or not a model should be shared across multiple processes instead of being loaded per process. This is primarily useful for large models that can't fit multiple copies in memory. Multi-process support may vary by runner, but this will fall back to loading per process as necessary. See https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html
update_model_path
update_model_path(
model_path: Optional[str] = None
)
Updates the model paths produced by side inputs.
validate_inference_args
validate_inference_args(
inference_args: Optional[Dict[str, Any]]
)
Validates inference_args passed in the inference call.
Because most frameworks do not need extra arguments in their predict() call, the default behavior is to error out if inference_args are present.
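The default behavior described above could look roughly like this. The exception type and message are assumptions, as is the choice to let an empty dictionary pass.

```python
from typing import Any, Dict, Optional


def validate_inference_args(inference_args: Optional[Dict[str, Any]]) -> None:
    # Reject any non-empty inference_args; None and {} both pass.
    if inference_args:
        raise ValueError("inference_args were provided, but none are expected")


validate_inference_args(None)  # passes silently

try:
    validate_inference_args({"temperature": 0.5})
    rejected = False
except ValueError:
    rejected = True
```

A handler whose model accepts extra arguments would override this method to accept, and ideally check, the keys it supports.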
with_postprocess_fn
with_postprocess_fn(
fn: Callable[[PredictionT], PostProcessT]
) -> 'ModelHandler[ExampleT, PostProcessT, ModelT, PostProcessT]'
Returns a new ModelHandler with a postprocessing function associated with it. The postprocessing function will be run after inference and should map the base ModelHandler's output type to your desired output type. If you apply multiple postprocessing functions, they will be run on your original inference result in order from first applied to last applied.
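The ordering rule for multiple postprocessing functions can be sketched in plain Python. `apply_postprocess_fns` is a hypothetical helper that only models the ordering; it is not part of the API.

```python
from typing import Any, Callable, List


def apply_postprocess_fns(prediction: Any, fns: List[Callable[[Any], Any]]) -> Any:
    # fns is in the order the functions were attached; run first to last.
    for fn in fns:
        prediction = fn(prediction)
    return prediction


# Attached first: add 1. Attached second: multiply by 10.
attached = [lambda p: p + 1, lambda p: p * 10]
result = apply_postprocess_fns(2, attached)  # (2 + 1) * 10
```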
with_preprocess_fn
with_preprocess_fn(
fn: Callable[[PreProcessT], ExampleT]
) -> 'ModelHandler[PreProcessT, PredictionT, ModelT, PreProcessT]'
Returns a new ModelHandler with a preprocessing function associated with it. The preprocessing function will be run before batching/inference and should map your input PCollection to the base ModelHandler's input type. If you apply multiple preprocessing functions, they will be run on your original PCollection in order from last applied to first applied.
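The reverse ordering for preprocessing functions can be sketched the same way. `apply_preprocess_fns` is a hypothetical helper that only models the ordering; the function attached first sits closest to the base handler, so it runs last.

```python
from typing import Any, Callable, List


def apply_preprocess_fns(element: Any, fns: List[Callable[[Any], Any]]) -> Any:
    # fns is in the order the functions were attached; the latest runs first.
    for fn in reversed(fns):
        element = fn(element)
    return element


# Attached first: float (produces the base handler's input type).
# Attached second: str.strip (runs first, on the raw PCollection element).
attached = [float, str.strip]
example = apply_preprocess_fns(" 3.5 ", attached)
```

This ordering lets each newly attached function accept a rawer input type and hand its output to the functions attached before it.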