A decorator that registers a TFP object as composite-friendly.
tfp.experimental.register_composite(cls)
This registration is not required to call as_composite on instances of a given distribution (or bijector or other TFP object). It is required, however, if a SavedModel with functions that accept or return composite wrappers of this object will be loaded in Python without as_composite having been called already.
Example:
```python
import tensorflow as tf
import tensorflow_probability as tfp

class MyDistribution(tfp.distributions.Distribution):
  ...

# Without registration, this will fail to load.
model = tf.saved_model.load(
    '/path/to/sm_with_funcs_returning_composite_tensor_MyDistribution')
```

Instead:

```python
# Registering the class at definition time makes it loadable.
@tfp.experimental.register_composite
class MyDistribution(tfp.distributions.Distribution):
  ...

# This will load.
model = tf.saved_model.load(
    '/path/to/sm_with_funcs_returning_composite_tensor_MyDistribution')
```
Args:
  cls: A subclass of Distribution.
Returns:
  The input, with the side effect of registering it as a composite-friendly distribution.
Raises:
  TypeError: If cls does not have _composite_tensor_params, or if registration fails (cls is not convertible).
  NotImplementedError: If registration fails (cls is not convertible).
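The Returns and Raises contract can be illustrated with a minimal stand-in decorator. This is a hypothetical sketch, not TFP's code: checked_register, GoodDist, and BadDist are invented names, and the tuple assigned to _composite_tensor_params is an assumed placeholder, since only the attribute's presence is checked here.

```python
# Hypothetical stand-in mirroring the documented contract:
# return cls unchanged, or raise TypeError if _composite_tensor_params
# is missing. Not TFP's implementation.
def checked_register(cls):
    if not hasattr(cls, "_composite_tensor_params"):
        raise TypeError(
            f"{cls.__name__} must define _composite_tensor_params.")
    # A real implementation would record cls in a registry here.
    return cls


@checked_register
class GoodDist:
    # Assumed placeholder value; only the attribute's presence matters here.
    _composite_tensor_params = ("loc", "scale")


class BadDist:
    pass
```

Because the decorator returns its input, GoodDist is still the same class after decoration, while checked_register(BadDist) raises TypeError.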