tf.experimental.dispatch_for_binary_elementwise_assert_apis

Decorator to override default implementation for binary elementwise assert APIs.

The decorated function (known as the "elementwise assert handler") overrides the default implementation for any binary elementwise assert API whenever the values of the first two arguments (typically named x and y) match the specified type annotations. The handler is called with three arguments:

elementwise_assert_handler(assert_func, x, y)

Where x and y are the first two arguments to the binary elementwise assert operation, and assert_func is a TensorFlow function that takes two parameters and performs the elementwise assert operation (e.g., tf.debugging.assert_equal).

The following example shows how this decorator can be used to update all binary elementwise assert operations to handle a MaskedTensor type:

```python
import tensorflow as tf

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

@tf.experimental.dispatch_for_binary_elementwise_assert_apis(
    MaskedTensor, MaskedTensor)
def binary_elementwise_assert_api_handler(assert_func, x, y):
  # Compare only the positions that are unmasked (True) in both operands.
  merged_mask = tf.logical_and(x.mask, y.mask)
  selected_x_values = tf.boolean_mask(x.values, merged_mask)
  selected_y_values = tf.boolean_mask(y.values, merged_mask)
  return assert_func(selected_x_values, selected_y_values)

# Only index 2 is unmasked in both a and b, and 0 == 0 there.
a = MaskedTensor([1, 1, 0, 1, 1], [False, False, True, True, True])
b = MaskedTensor([2, 2, 0, 2, 2], [True, True, True, False, False])
tf.debugging.assert_equal(a, b)  # Assert passes; no exception is raised.

# Every position is unmasked, and 1 > 2 fails at the last index.
a = MaskedTensor([1, 1, 1, 1, 1], [True, True, True, True, True])
b = MaskedTensor([0, 0, 0, 0, 2], [True, True, True, True, True])
tf.debugging.assert_greater(a, b)
# Traceback (most recent call last):
# ...
# InvalidArgumentError: Condition x > y did not hold.
```
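To see why the first assertion passes and the second fails, the handler's masking step can be replayed in plain Python on the same data. This is an illustrative stand-in for `tf.logical_and` followed by `tf.boolean_mask`, not TensorFlow code; the helper name `compared_pairs` is hypothetical.

```python
def compared_pairs(x_values, x_mask, y_values, y_mask):
    """Hypothetical helper: the (x, y) pairs left after merging both masks."""
    # A position survives only if it is unmasked (True) in both operands,
    # mirroring logical_and on the masks followed by boolean masking.
    return [(xv, yv)
            for xv, xm, yv, ym in zip(x_values, x_mask, y_values, y_mask)
            if xm and ym]


# First example: only index 2 survives the merged mask, and 0 == 0 there.
pairs_eq = compared_pairs([1, 1, 0, 1, 1], [False, False, True, True, True],
                          [2, 2, 0, 2, 2], [True, True, True, False, False])
print(pairs_eq)                          # [(0, 0)]
print(all(x == y for x, y in pairs_eq))  # True

# Second example: every position survives, and 1 > 2 fails at the last one.
pairs_gt = compared_pairs([1, 1, 1, 1, 1], [True] * 5,
                          [0, 0, 0, 0, 2], [True] * 5)
print([x > y for x, y in pairs_gt])      # [True, True, True, True, False]
```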

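Args

x_type: A type annotation for the first argument (x). The handler is called only when x matches this annotation and y matches y_type.
y_type: A type annotation for the second argument (y). The handler is called only when y matches this annotation and x matches x_type.

Returns

A decorator.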

Registered APIs

The binary elementwise assert APIs are: