Decorator to override default implementation for binary elementwise assert APIs.
tf.experimental.dispatch_for_binary_elementwise_assert_apis(
x_type, y_type
)
The decorated function (known as the "elementwise assert handler") overrides the default implementation for any binary elementwise assert API whenever the values of the first two arguments (typically named x and y) match the specified type annotations. The handler is called with three arguments:
elementwise_assert_handler(assert_func, x, y)
Here, x and y are the first two arguments to the binary elementwise assert operation, and assert_func is a TensorFlow function that takes two parameters and performs the elementwise assert operation (e.g., tf.debugging.assert_equal).
The following example shows how this decorator can be used to update all binary elementwise assert operations to handle a MaskedTensor type:
import tensorflow as tf

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

@tf.experimental.dispatch_for_binary_elementwise_assert_apis(MaskedTensor, MaskedTensor)
def binary_elementwise_assert_api_handler(assert_func, x, y):
  # Compare only the positions that are unmasked in both operands.
  merged_mask = tf.logical_and(x.mask, y.mask)
  selected_x_values = tf.boolean_mask(x.values, merged_mask)
  selected_y_values = tf.boolean_mask(y.values, merged_mask)
  assert_func(selected_x_values, selected_y_values)

a = MaskedTensor([1, 1, 0, 1, 1], [False, False, True, True, True])
b = MaskedTensor([2, 2, 0, 2, 2], [True, True, True, False, False])
tf.debugging.assert_equal(a, b)  # assert passed; no exception was thrown

a = MaskedTensor([1, 1, 1, 1, 1], [True, True, True, True, True])
b = MaskedTensor([0, 0, 0, 0, 2], [True, True, True, True, True])
tf.debugging.assert_greater(a, b)
# Traceback (most recent call last):
# InvalidArgumentError: Condition x > y did not hold.
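Because a single registration covers every binary elementwise assert API, the same handler also changes the behavior of the other comparison asserts. The sketch below assumes TensorFlow 2.x in eager mode; it redefines MaskedTensor so that it runs on its own, and masked_assert_handler is a hypothetical name. It applies tf.debugging.assert_less_equal and tf.debugging.assert_none_equal to masked values.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor


@tf.experimental.dispatch_for_binary_elementwise_assert_apis(
    MaskedTensor, MaskedTensor)
def masked_assert_handler(assert_func, x, y):
  # Compare only the positions that are unmasked in both operands.
  merged_mask = tf.logical_and(x.mask, y.mask)
  assert_func(tf.boolean_mask(x.values, merged_mask),
              tf.boolean_mask(y.values, merged_mask))


a = MaskedTensor([1, 5, 3], [True, False, True])
b = MaskedTensor([2, 0, 3], [True, True, True])

# The compared pairs are (1, 2) and (3, 3); the masked-out 5 > 0 is ignored.
tf.debugging.assert_less_equal(a, b)  # passes

# assert_none_equal fails because the pair (3, 3) is equal.
try:
  tf.debugging.assert_none_equal(a, b)
except tf.errors.InvalidArgumentError:
  print("assert_none_equal raised InvalidArgumentError")
```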
Args | |
---|---|
x_type | A type annotation for the first argument (typically x), indicating when the API handler should be called. |
y_type | A type annotation for the second argument (typically y), indicating when the API handler should be called. |
Returns | |
---|---|
A decorator. |
Registered APIs
The binary elementwise assert APIs are:
tf.compat.v1.debugging.assert_equal(x, y, data, summarize, message, name)
tf.compat.v1.debugging.assert_greater(x, y, data, summarize, message, name)
tf.compat.v1.debugging.assert_greater_equal(x, y, data, summarize, message, name)
tf.compat.v1.debugging.assert_less(x, y, data, summarize, message, name)
tf.compat.v1.debugging.assert_less_equal(x, y, data, summarize, message, name)
tf.compat.v1.debugging.assert_near(x, y, rtol, atol, data, summarize, message, name)
tf.compat.v1.debugging.assert_none_equal(x, y, data, summarize, message, name)
tf.debugging.assert_equal(x, y, message, summarize, name)
tf.debugging.assert_greater(x, y, message, summarize, name)
tf.debugging.assert_greater_equal(x, y, message, summarize, name)
tf.debugging.assert_less(x, y, message, summarize, name)
tf.debugging.assert_less_equal(x, y, message, summarize, name)
tf.debugging.assert_near(x, y, rtol, atol, message, summarize, name)
tf.debugging.assert_none_equal(x, y, summarize, message, name)
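The registration also applies to the tf.compat.v1 variants in this list and to APIs with extra parameters such as assert_near. This sketch assumes TensorFlow 2.x in eager mode and assumes that arguments beyond x and y (for example atol) are bound into assert_func before the handler is invoked, which is why the handler passes only two values; masked_assert_handler is a hypothetical name.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor


@tf.experimental.dispatch_for_binary_elementwise_assert_apis(
    MaskedTensor, MaskedTensor)
def masked_assert_handler(assert_func, x, y):
  merged_mask = tf.logical_and(x.mask, y.mask)
  # Assumption: extra arguments such as atol are already bound into
  # assert_func, so only the selected x and y values are passed here.
  assert_func(tf.boolean_mask(x.values, merged_mask),
              tf.boolean_mask(y.values, merged_mask))


a = MaskedTensor([1.0, 2.0, 100.0], [True, True, False])
b = MaskedTensor([1.05, 2.05, 0.0], [True, True, True])

# A TF1-style API from the list above is dispatched as well.
tf.compat.v1.debugging.assert_greater(b, a)  # 1.05 > 1.0 and 2.05 > 2.0

# With atol=0.1 the unmasked values differ by about 0.05, so this passes.
tf.debugging.assert_near(a, b, atol=0.1)
```

With the default tolerances, tf.debugging.assert_near(a, b) would be expected to fail for the same inputs.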