![]() |
![]() |
![]() |
![]() |
TensorFlow Lite Authoring API provides a way to maintain your tf.function
models compatibile with TensorFlow Lite.
Setup
import tensorflow as tf
2023-09-02 11:13:37.714837: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2023-09-02 11:13:37.714883: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2023-09-02 11:13:37.714913: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
TensorFlow to TensorFlow Lite compatibility issue
If you want to use your TF model on devices, you need to convert it to a TFLite model to use it from TFLite interpreter. During the conversion, you might encounter a compatibility error because of unsupported TensorFlow ops by the TFLite builtin op set.
This is a kind of annoying issue. How can you detect it earlier like the model authoring time?
Note that the following code will fail on the converter.convert()
call.
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
return tf.cosh(x)
# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")
result = [1.]
# Convert the tf.function
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[f.get_concrete_function()], f)
try:
fb_model = converter.convert()
except Exception as e:
print(f"Got an exception: {e}")
Got an exception: /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:694:1: error: 'tf.Cosh' op is neither a custom op nor a flex op self._concrete_variable_creation_fn = tracing_compilation.trace_function( ^ /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:178:1: note: called from concrete_function = _maybe_define_function( ^ /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:284:1: note: called from concrete_function = _create_concrete_function( ^ /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:308:1: note: called from traced_func_graph = func_graph_module.func_graph_from_py_func( ^ /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1059:1: note: called from func_outputs = python_func(*func_args, **func_kwargs) ^ /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:597:1: note: called from out = weak_wrapped_fn().__wrapped__(*args, **kwds) ^ /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py:41:1: note: called from return api.converted_call( ^ /tmpfs/tmp/ipykernel_40871/885400331.py:5:1: note: called from /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py:2611:1: note: called from _, _, _op, _outputs = _op_def_library._apply_op_helper( ^ /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py:796:1: note: called from op = g._create_op_internal(op_type_name, inputs, dtypes=None, ^ /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:694:1: note: Error code: ERROR_NEEDS_FLEX_OPS self._concrete_variable_creation_fn = tracing_compilation.trace_function( ^ <unknown>:0: error: failed while converting: 'main': Some ops are not supported by the native TFLite runtime, you can enable TF kernels fallback using TF Select. See instructions: https://www.tensorflow.org/lite/guide/ops_select TF Select ops: Cosh Details: tf.Cosh(tensor<?xf32>) -> (tensor<?xf32>) : {device = ""} 2023-09-02 11:13:42.051725: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2023-09-02 11:13:42.051754: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. loc(fused["Cosh:", callsite("Cosh"("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py":694:1) at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":178:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":284:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":308:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py":1059:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py":597:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py":41:1 at callsite("/tmpfs/tmp/ipykernel_40871/885400331.py":5:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py":2611:1 at "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py":796:1)))))))))]): error: 'tf.Cosh' op is neither a custom op nor a flex op error: failed while converting: 'main': Some ops are not supported by the native TFLite runtime, you can enable TF kernels fallback using TF Select. See instructions: https://www.tensorflow.org/lite/guide/ops_select TF Select ops: Cosh Details: tf.Cosh(tensor<?xf32>) -> (tensor<?xf32>) : {device = ""}
Simple Target Aware Authoring usage
We introduced Authoring API to detect the TensorFlow Lite compatibility issue during the model authoring time.
You just need to add @tf.lite.experimental.authoring.compatible
decorator to wrap your tf.function
model to check TFLite compatibility.
After this, the compatibility will be checked automatically when you evaluate your model.
@tf.lite.experimental.authoring.compatible
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
return tf.cosh(x)
# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")
COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model conversion for TensorFlow Lite. https://www.tensorflow.org/lite/guide/ops_select Op: tf.Cosh - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py:796 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py:2611 - /tmpfs/tmp/ipykernel_40871/885400331.py:5 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py:41 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:597 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1059 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:308 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:284 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:178 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:694 result = [1.] 2023-09-02 11:13:42.143697: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2023-09-02 11:13:42.143728: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. loc(fused["Cosh:", callsite("Cosh"("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py":694:1) at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":178:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":284:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":308:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py":1059:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py":597:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py":41:1 at callsite("/tmpfs/tmp/ipykernel_40871/885400331.py":5:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py":2611:1 at "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py":796:1)))))))))]): error: 'tf.Cosh' op is neither a custom op nor a flex op error: failed while converting: 'main': Some ops are not supported by the native TFLite runtime, you can enable TF kernels fallback using TF Select. See instructions: https://www.tensorflow.org/lite/guide/ops_select TF Select ops: Cosh Details: tf.Cosh(tensor<?xf32>) -> (tensor<?xf32>) : {device = ""}
If any TensorFlow Lite compatibility issue is found, it will show COMPATIBILITY WARNING
or COMPATIBILITY ERROR
with the exact location of the problematic op. In this example, it shows the location of tf.Cosh
op in your tf.function model.
You can also check the compatiblity log with the <function_name>.get_compatibility_log()
method.
compatibility_log = '\n'.join(f.get_compatibility_log())
print (f"compatibility_log = {compatibility_log}")
compatibility_log = COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model conversion for TensorFlow Lite. https://www.tensorflow.org/lite/guide/ops_select Op: tf.Cosh - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py:796 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py:2611 - /tmpfs/tmp/ipykernel_40871/885400331.py:5 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py:41 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:597 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1059 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:308 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:284 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:178 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:694
Raise an exception for an incompatibility
You can provide an option to the @tf.lite.experimental.authoring.compatible
decorator. The raise_exception
option gives you an exception when you're trying to evaluate the decorated model.
@tf.lite.experimental.authoring.compatible(raise_exception=True)
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
return tf.cosh(x)
# Evaluate the tf.function
try:
result = f(tf.constant([0.0]))
print (f"result = {result}")
except Exception as e:
print(f"Got an exception: {e}")
COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model conversion for TensorFlow Lite. https://www.tensorflow.org/lite/guide/ops_select Op: tf.Cosh - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py:796 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py:2611 - /tmpfs/tmp/ipykernel_40871/885400331.py:5 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py:41 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:597 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1059 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:308 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:284 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:178 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:694 Got an exception: CompatibilityException at <tensorflow.python.eager.polymorphic_function.polymorphic_function.Function object at 0x7f0fce9a66a0> 2023-09-02 11:13:42.249396: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2023-09-02 11:13:42.249436: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. loc(fused["Cosh:", callsite("Cosh"("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py":694:1) at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":178:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":284:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":308:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py":1059:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py":597:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py":41:1 at callsite("/tmpfs/tmp/ipykernel_40871/885400331.py":5:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py":2611:1 at "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py":796:1)))))))))]): error: 'tf.Cosh' op is neither a custom op nor a flex op error: failed while converting: 'main': Some ops are not supported by the native TFLite runtime, you can enable TF kernels fallback using TF Select. See instructions: https://www.tensorflow.org/lite/guide/ops_select TF Select ops: Cosh Details: tf.Cosh(tensor<?xf32>) -> (tensor<?xf32>) : {device = ""}
Specifying "Select TF ops" usage
If you're already aware of Select TF ops usage, you can tell this to the Authoring API by setting converter_target_spec
. It's the same tf.lite.TargetSpec object you'll use it for tf.lite.TFLiteConverter API.
target_spec = tf.lite.TargetSpec()
target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS,
]
@tf.lite.experimental.authoring.compatible(converter_target_spec=target_spec, raise_exception=True)
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
return tf.cosh(x)
# Evaluate the tf.function
result = f(tf.constant([0.0]))
print (f"result = {result}")
result = [1.] 2023-09-02 11:13:42.358160: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2023-09-02 11:13:42.358195: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. 2023-09-02 11:13:42.371078: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2178] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s): Flex ops: FlexCosh Details: tf.Cosh(tensor<?xf32>) -> (tensor<?xf32>) : {device = ""} See instructions: https://www.tensorflow.org/lite/guide/ops_select
Checking GPU compatibility
If you want to ensure your model is compatibile with GPU delegate of TensorFlow Lite, you can set experimental_supported_backends
of tf.lite.TargetSpec.
The following example shows how to ensure GPU delegate compatibility of your model. Note that this model has compatibility issues since it uses a 2D tensor with tf.slice operator and unsupported tf.cosh operator. You'll see two COMPATIBILITY WARNING
with the location information.
target_spec = tf.lite.TargetSpec()
target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS,
]
target_spec.experimental_supported_backends = ["GPU"]
@tf.lite.experimental.authoring.compatible(converter_target_spec=target_spec)
@tf.function(input_signature=[
tf.TensorSpec(shape=[4, 4], dtype=tf.float32)
])
def func(x):
y = tf.cosh(x)
return y + tf.slice(x, [1, 1], [1, 1])
result = func(tf.ones(shape=(4,4), dtype=tf.float32))
'tfl.slice' op is not GPU compatible: SLICE supports for 3 or 4 dimensional tensors only, but node has 2 dimensional tensors. - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py:1260 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py:150 - /tmpfs/tmp/ipykernel_40871/3833138856.py:13 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py:41 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:597 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1059 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:308 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:284 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:178 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:694 'tf.Cosh' op is not GPU compatible: Not supported custom op FlexCosh - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py:796 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py:2611 - /tmpfs/tmp/ipykernel_40871/3833138856.py:12 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py:41 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:597 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1059 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:308 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:284 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py:178 - /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:694 COMPATIBILITY WARNING: op 'tf.Cosh, tfl.slice' aren't compatible with TensorFlow Lite GPU delegate. https://www.tensorflow.org/lite/performance/gpu 2023-09-02 11:13:42.495257: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2023-09-02 11:13:42.495288: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. 2023-09-02 11:13:42.509206: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:2178] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s): Flex ops: FlexCosh Details: tf.Cosh(tensor<4x4xf32>) -> (tensor<4x4xf32>) : {device = ""} See instructions: https://www.tensorflow.org/lite/guide/ops_select loc(fused["Cosh:", callsite("Cosh"("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py":694:1) at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":178:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":284:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":308:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py":1059:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py":597:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py":41:1 at callsite("/tmpfs/tmp/ipykernel_40871/3833138856.py":12:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/ops/gen_math_ops.py":2611:1 at "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py":796:1)))))))))]): error: 'tf.Cosh' op is not GPU compatible: Not supported custom op FlexCosh loc(fused["Slice:", callsite("Slice"("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py":694:1) at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":178:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":284:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py":308:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py":1059:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py":597:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py":41:1 at callsite("/tmpfs/tmp/ipykernel_40871/3833138856.py":13:1 at callsite("/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py":150:1 at "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py":1260:1)))))))))]): error: 'tfl.slice' op is not GPU compatible: SLICE supports for 3 or 4 dimensional tensors only, but node has 2 dimensional tensors.
Read more
For more information, please refer to: