Category crossing layer.

Inherits From: Layer, Module

This layer concatenates multiple categorical inputs into a single categorical output (similar to Cartesian product). The output dtype is string.


import tensorflow as tf

inp_1 = ['a', 'b', 'c']
inp_2 = ['d', 'e', 'f']
layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing()
layer([inp_1, inp_2])
<tf.Tensor: shape=(3, 1), dtype=string, numpy=
  array([[b'a_X_d'],
         [b'b_X_e'],
         [b'c_X_f']], dtype=object)>

inp_1 = ['a', 'b', 'c']
inp_2 = ['d', 'e', 'f']
layer = tf.keras.layers.experimental.preprocessing.CategoryCrossing(
    separator='-')
layer([inp_1, inp_2])
<tf.Tensor: shape=(3, 1), dtype=string, numpy=
  array([[b'a-d'],
         [b'b-e'],
         [b'c-f']], dtype=object)>

depth: Depth of input crossing. Defaults to None, in which case all inputs are crossed into a single output. It can also be an int or a tuple/list of ints. Passing an integer creates crossed outputs for every depth up to and including that integer, i.e., depths [1, 2, ..., depth]. Passing a tuple of integers creates crossed outputs only for the depths listed, i.e., depth=(N1, N2) creates all possible crossed outputs with depth equal to N1 or N2. For example, with inputs a, b and c, depth=2 means the output will be [a; b; c; cross(a, b); cross(b, c); cross(c, a)].
separator: A string added between each input being joined. Defaults to '_X_'.
name: Name to give to the layer.
**kwargs: Keyword arguments to construct a layer.
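
The depth rules above can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the layer's implementation: the helper name crossing_plan is hypothetical, and it only lists which inputs would be combined for a given depth setting. The order of combinations follows itertools.combinations and may differ from the order in which the layer emits its outputs.

```python
from itertools import combinations


def crossing_plan(names, depth=None):
    """List the groups of input names that would be crossed together.

    Hypothetical helper for illustration only; it mirrors the documented
    meaning of `depth`, not the layer's internal code.
    """
    if depth is None:
        # A single cross over all inputs.
        depths = (len(names),)
    elif isinstance(depth, int):
        # Every depth from 1 up to and including `depth`.
        depths = tuple(range(1, depth + 1))
    else:
        # Only the depths listed in the tuple/list.
        depths = tuple(depth)

    plan = []
    for d in depths:
        plan.extend(combinations(names, d))
    return plan


print(crossing_plan(["a", "b", "c"]))
# [('a', 'b', 'c')]
print(crossing_plan(["a", "b", "c"], depth=2))
# [('a',), ('b',), ('c',), ('a', 'b'), ('a', 'c'), ('b', 'c')]
print(crossing_plan(["a", "b", "c"], depth=(1, 3)))
# [('a',), ('b',), ('c',), ('a', 'b', 'c')]
```

With three inputs and depth=2 the plan has six entries: the three inputs on their own plus the three pairwise crosses.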

Input shape: a list of string or int tensors or sparse tensors of shape [batch_size, d1, ..., dm]

Output shape: a single string or int tensor or sparse tensor of shape [batch_size, d1, ..., dm]

If any input is RaggedTensor, the output is RaggedTensor. Else, if any input is SparseTensor, the output is SparseTensor. Otherwise, the output is Tensor.

Example: (depth=None) If the layer receives three inputs: a=[[1], [4]], b=[[2], [5]], c=[[3], [6]] the output will be a string tensor: [[b'1_X_2_X_3'], [b'4_X_5_X_6']]
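
The crossing in the depth=None example can be reproduced with plain Python lists. This is a rough sketch under stated assumptions, not the layer itself: cross_rows is a hypothetical helper, it returns Python str values rather than the byte strings a string tensor holds, and its handling of rows with several values (a Cartesian product within the row) is an assumption based on the description of the layer as similar to a Cartesian product.

```python
from itertools import product


def cross_rows(inputs, separator="_X_"):
    """Cross rows that share a batch index across several inputs.

    `inputs` is a list of batched inputs; each input is a list of rows and
    each row is a list of values. Hypothetical helper for illustration.
    """
    crossed = []
    # zip(*inputs) yields one row from every input at the same batch index.
    for rows in zip(*inputs):
        crossed.append(
            [separator.join(str(v) for v in combo) for combo in product(*rows)]
        )
    return crossed


a = [[1], [4]]
b = [[2], [5]]
c = [[3], [6]]
print(cross_rows([a, b, c]))
# [['1_X_2_X_3'], ['4_X_5_X_6']]
```

Passing separator='-' to the sketch changes only the joining string, which mirrors the separator argument of the layer.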

Example: (depth is an integer) With the same input above, and if depth=2, the output will be a list of 6 string tensors: [[b'1'], [b'4']] [[b'2'], [b'5']] [[b'3'], [b'6']] [[b'1_X_2'], [b'4_X_5']], [[b'2_X_3'], [b'5_X_6']], [[b'3_X_1'], [b'6_