Driver class to drive model inference with TensorFlow Lite.

An Interpreter encapsulates a pre-trained TensorFlow Lite model, in which operations are executed for model inference.
For example, if a model takes only one input and returns only one output:
```java
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
  interpreter.run(input, output);
}
```
If a model takes multiple inputs or outputs:

```java
Object[] inputs = {input0, input1, /* ... */};
Map<Integer, Object> map_of_indices_to_outputs = new HashMap<>();
// Float tensor of shape 3x2x4: 24 float elements, 4 bytes each.
FloatBuffer ith_output =
    ByteBuffer.allocateDirect(3 * 2 * 4 * 4)
        .order(ByteOrder.nativeOrder())
        .asFloatBuffer();
map_of_indices_to_outputs.put(i, ith_output);
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
  interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
}
```
If a model takes or produces string tensors:

```java
String[] input = {"foo", "bar"};       // Input tensor shape is [2].
String[][] output = new String[3][2];  // Output tensor shape is [3, 2].
try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
  interpreter.run(input, output);
}
```
The orders of inputs and outputs, as well as the default input shapes, are determined when the TensorFlow model is converted to a TensorFlow Lite model with Toco.
When inputs are provided as (multi-dimensional) arrays, the corresponding input tensor(s) will be implicitly resized according to that array's shape. When inputs are provided as Buffer types, no implicit resizing is done; the caller must ensure that the Buffer byte size either matches that of the corresponding tensor, or that they first resize the tensor via resizeInput(int, int[]). Tensor shape and type information can be obtained via the Tensor class, available via getInputTensor(int) and getOutputTensor(int).
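For instance, a minimal sketch of sizing an input Buffer from the tensor's reported byte size (the `interpreter` and `output` objects are assumed to already exist, as in the examples above):

```java
// Query the first input tensor and allocate a direct, native-ordered buffer
// whose byte size matches the tensor exactly, so no resizing is required.
Tensor inputTensor = interpreter.getInputTensor(0);
ByteBuffer input =
    ByteBuffer.allocateDirect(inputTensor.numBytes()).order(ByteOrder.nativeOrder());
// ... fill `input` with data ...
input.rewind();  // Leave the buffer at the appropriate read position.
interpreter.run(input, output);  // `output` can be sized analogously via getOutputTensor(0).
```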
WARNING: Instances of an Interpreter are not thread-safe. An Interpreter owns resources that must be explicitly freed by invoking close().

The TFLite library is built against NDK API 19. It may work for Android API levels below 19, but this is not guaranteed.
Nested Classes
| class | Interpreter.Options | An options class for controlling runtime interpreter behavior. |
Public Constructors

| Interpreter(File modelFile) | Initializes an Interpreter. |
| Interpreter(File modelFile, int numThreads) | Deprecated. Prefer the Interpreter(File, Options) constructor. This method will be removed in a future release. |
| Interpreter(File modelFile, Interpreter.Options options) | Initializes an Interpreter with a set of custom options. |
| Interpreter(ByteBuffer byteBuffer) | Initializes an Interpreter with a ByteBuffer of a model file. |
| Interpreter(ByteBuffer byteBuffer, int numThreads) | Deprecated. Prefer the Interpreter(ByteBuffer, Options) constructor. This method will be removed in a future release. |
| Interpreter(MappedByteBuffer mappedByteBuffer) | Deprecated. Prefer the Interpreter(ByteBuffer, Options) constructor. This method will be removed in a future release. |
| Interpreter(ByteBuffer byteBuffer, Interpreter.Options options) | Initializes an Interpreter with a ByteBuffer of a model file and a set of custom Interpreter.Options. |
Public Methods

| void | allocateTensors() | Explicitly updates allocations for all tensors, if necessary. |
| void | close() | Releases resources associated with the Interpreter. |
| int | getInputIndex(String opName) | Gets the index of an input given the op name of the input. |
| Tensor | getInputTensor(int inputIndex) | Gets the Tensor associated with the provided input index. |
| int | getInputTensorCount() | Gets the number of input tensors. |
| Long | getLastNativeInferenceDurationNanoseconds() | Returns native inference timing. |
| int | getOutputIndex(String opName) | Gets the index of an output given the op name of the output. |
| Tensor | getOutputTensor(int outputIndex) | Gets the Tensor associated with the provided output index. |
| int | getOutputTensorCount() | Gets the number of output tensors. |
| void | modifyGraphWithDelegate(Delegate delegate) | Advanced: Modifies the graph with the provided Delegate. |
| void | resetVariableTensors() | Advanced: Resets all variable tensors to the default value. |
| void | resizeInput(int idx, int[] dims, boolean strict) | Resizes the idx-th input of the native model to the given dims. |
| void | resizeInput(int idx, int[] dims) | Resizes the idx-th input of the native model to the given dims. |
| void | run(Object input, Object output) | Runs model inference if the model takes only one input, and provides only one output. |
| void | runForMultipleInputsOutputs(Object[] inputs, Map<Integer, Object> outputs) | Runs model inference if the model takes multiple inputs, or returns multiple outputs. |
| void | setCancelled(boolean cancelled) | Advanced: Interrupts inference in the middle of a call to run(Object, Object). |
| void | setNumThreads(int numThreads) | Deprecated. Prefer using Interpreter.Options.setNumThreads(int) for controlling multi-threading. This method will be removed in a future release. |
Public Constructors
public Interpreter (File modelFile)

Initializes an Interpreter.

Throws
| IllegalArgumentException | if modelFile does not encode a valid TensorFlow Lite model. |
public Interpreter (File modelFile, int numThreads)

This constructor was deprecated. Prefer using the Interpreter(File, Options) constructor; this method will be removed in a future release.

Initializes an Interpreter and specifies the number of threads used for inference.
public Interpreter (File modelFile, Interpreter.Options options)

Initializes an Interpreter with a set of custom options.

Throws
| IllegalArgumentException | if modelFile does not encode a valid TensorFlow Lite model. |
public Interpreter (ByteBuffer byteBuffer)

Initializes an Interpreter with a ByteBuffer of a model file.

The ByteBuffer should not be modified after the construction of an Interpreter. The ByteBuffer can be either a MappedByteBuffer that memory-maps a model file, or a direct ByteBuffer of nativeOrder() that contains the bytes content of a model.

Throws
| IllegalArgumentException | if byteBuffer is not a MappedByteBuffer nor a direct ByteBuffer of nativeOrder. |
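As an illustration, one common way to obtain such a buffer is to memory-map the model file. This is a minimal sketch; the file path is hypothetical:

```java
// Memory-map the model file read-only; a MappedByteBuffer satisfies this
// constructor without copying the model into the Java heap.
try (FileInputStream stream = new FileInputStream("/path/to/model.tflite");
     FileChannel channel = stream.getChannel()) {
  MappedByteBuffer modelBuffer =
      channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
  try (Interpreter interpreter = new Interpreter(modelBuffer)) {
    // Run inference...
  }
}
```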
public Interpreter (ByteBuffer byteBuffer, int numThreads)

This constructor was deprecated. Prefer using the Interpreter(ByteBuffer, Options) constructor; this method will be removed in a future release.

Initializes an Interpreter with a ByteBuffer of a model file and specifies the number of threads used for inference.

The ByteBuffer should not be modified after the construction of an Interpreter. The ByteBuffer can be either a MappedByteBuffer that memory-maps a model file, or a direct ByteBuffer of nativeOrder() that contains the bytes content of a model.
public Interpreter (MappedByteBuffer mappedByteBuffer)

This constructor was deprecated. Prefer using the Interpreter(ByteBuffer, Options) constructor; this method will be removed in a future release.

Initializes an Interpreter with a MappedByteBuffer to the model file.

The MappedByteBuffer should remain unchanged after the construction of an Interpreter.
public Interpreter (ByteBuffer byteBuffer, Interpreter.Options options)

Initializes an Interpreter with a ByteBuffer of a model file and a set of custom Interpreter.Options.

The ByteBuffer should not be modified after the construction of an Interpreter. The ByteBuffer can be either a MappedByteBuffer that memory-maps a model file, or a direct ByteBuffer of nativeOrder() that contains the bytes content of a model.

Throws
| IllegalArgumentException | if byteBuffer is not a MappedByteBuffer nor a direct ByteBuffer of nativeOrder. |
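For example, a minimal sketch combining a model buffer with custom options (`modelBuffer`, `input`, and `output` are assumed to exist, as above; the thread count here is arbitrary):

```java
// Options calls are chainable; configure before constructing the interpreter.
Interpreter.Options options = new Interpreter.Options().setNumThreads(4);
try (Interpreter interpreter = new Interpreter(modelBuffer, options)) {
  interpreter.run(input, output);
}
```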
Public Methods
public void allocateTensors ()

Explicitly updates allocations for all tensors, if necessary.

This will propagate shapes and memory allocations for all dependent tensors using the input tensor shape(s) as given.

Note: This call is *purely optional*. Tensor allocation will occur automatically during execution if any input tensors have been resized. This call is most useful in determining the shapes for any output tensors before executing the graph, e.g.,

```java
interpreter.resizeInput(0, new int[] {1, 4, 4, 3});
interpreter.allocateTensors();
FloatBuffer input = FloatBuffer.allocate(interpreter.getInputTensor(0).numElements());
// Populate inputs...
FloatBuffer output = FloatBuffer.allocate(interpreter.getOutputTensor(0).numElements());
interpreter.run(input, output);
// Process outputs...
```

Throws
| IllegalStateException | if the graph's tensors could not be successfully allocated. |
public void close ()

Releases resources associated with the Interpreter.
public int getInputIndex (String opName)

Gets the index of an input given the op name of the input.

Throws
| IllegalArgumentException | if opName does not match any input in the model used to initialize the Interpreter. |
public Tensor getInputTensor (int inputIndex)

Gets the Tensor associated with the provided input index.

Throws
| IllegalArgumentException | if inputIndex is negative or is not smaller than the number of model inputs. |
public int getInputTensorCount ()
Gets the number of input tensors.
public Long getLastNativeInferenceDurationNanoseconds ()

Returns native inference timing.

Throws
| IllegalArgumentException | if the model is not initialized by the Interpreter. |
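For example, a small sketch of logging the last inference latency (the boxed Long is defensively null-checked, on the assumption that timing may be unavailable before a run):

```java
interpreter.run(input, output);
Long latencyNanos = interpreter.getLastNativeInferenceDurationNanoseconds();
if (latencyNanos != null) {
  System.out.println("Inference took " + latencyNanos / 1_000_000 + " ms");
}
```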
public int getOutputIndex (String opName)

Gets the index of an output given the op name of the output.

Throws
| IllegalArgumentException | if opName does not match any output in the model used to initialize the Interpreter. |
public Tensor getOutputTensor (int outputIndex)

Gets the Tensor associated with the provided output index.

Note: Output tensor details (e.g., shape) may not be fully populated until after inference is executed. If you need updated details *before* running inference (e.g., after resizing an input tensor, which may invalidate output tensor shapes), use allocateTensors() to explicitly trigger allocation and shape propagation. Note that, for graphs with output shapes that are dependent on input *values*, the output shape may not be fully determined until running inference.

Throws
| IllegalArgumentException | if outputIndex is negative or is not smaller than the number of model outputs. |
public int getOutputTensorCount ()
Gets the number of output Tensors.
public void modifyGraphWithDelegate (Delegate delegate)

Advanced: Modifies the graph with the provided Delegate.

Note: The typical path for providing delegates is via Interpreter.Options.addDelegate(Delegate), at creation time. This path should only be used when a delegate might require coordinated interaction between Interpreter creation and delegate application.

WARNING: This is an experimental API and subject to change.

Throws
| IllegalArgumentException | if an error occurs when modifying the graph with delegate. |
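For comparison, a sketch of the typical creation-time path. GpuDelegate here stands in for any Delegate implementation and ships in the separate tensorflow-lite-gpu artifact; `modelFile`, `input`, and `output` are assumed to exist:

```java
// Preferred: attach the delegate when the interpreter is created, via Options.
try (GpuDelegate delegate = new GpuDelegate();
     Interpreter interpreter =
         new Interpreter(modelFile, new Interpreter.Options().addDelegate(delegate))) {
  interpreter.run(input, output);
}
```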
public void resetVariableTensors ()
Advanced: Resets all variable tensors to the default value.
If a variable tensor doesn't have an associated buffer, it will be reset to zero.
WARNING: This is an experimental API and subject to change.
public void resizeInput (int idx, int[] dims, boolean strict)

Resizes the idx-th input of the native model to the given dims.

When `strict` is true, only unknown dimensions can be resized. Unknown dimensions are indicated as `-1` in the array returned by `Tensor.shapeSignature()`.

Throws
| IllegalArgumentException | if idx is negative or is not smaller than the number of model inputs; if an error occurs when resizing the idx-th input; or if attempting to resize a tensor with fixed dimensions when `strict` is true. |
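For example, a sketch assuming a model whose first input has shape signature [-1, 224, 224, 3] (a dynamic batch dimension):

```java
int[] signature = interpreter.getInputTensor(0).shapeSignature();  // e.g., {-1, 224, 224, 3}
// With strict == true, only the unknown (-1) batch dimension may change.
interpreter.resizeInput(0, new int[] {4, 224, 224, 3}, /*strict=*/ true);
interpreter.allocateTensors();  // Optional: propagate shapes before running.
```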
public void resizeInput (int idx, int[] dims)

Resizes the idx-th input of the native model to the given dims.

Throws
| IllegalArgumentException | if idx is negative or is not smaller than the number of model inputs, or if an error occurs when resizing the idx-th input. |
public void run (Object input, Object output)

Runs model inference if the model takes only one input, and provides only one output.

Warning: The API is more efficient if a Buffer (preferably direct, but not required) is used as the input/output data type. Please consider using Buffer to feed and fetch primitive data for better performance. The following concrete Buffer types are supported:

- ByteBuffer - compatible with any underlying primitive Tensor type.
- FloatBuffer - compatible with float Tensors.
- IntBuffer - compatible with int32 Tensors.
- LongBuffer - compatible with int64 Tensors.

Note that boolean types are only supported as arrays, not Buffers, or as scalar inputs.

Parameters
| input | an array or multidimensional array, or a Buffer of primitive types including int, float, long, and byte. Buffer is the preferred way to pass large input data for primitive types, whereas string types require using the (multi-dimensional) array input path. When a Buffer is used, its content should remain unchanged until model inference is done, and the caller must ensure that the Buffer is at the appropriate read position. A null value is allowed only if the caller is using a Delegate that allows buffer handle interop, and such a buffer has been bound to the input Tensor. |
| output | a multidimensional array of output data, or a Buffer of primitive types including int, float, long, and byte. When a Buffer is used, the caller must ensure that it is set to the appropriate write position. A null value is allowed only if the caller is using a Delegate that allows buffer handle interop, and such a buffer has been bound to the output Tensor. See Interpreter.Options.setAllowBufferHandleOutput(boolean). |

Throws
| IllegalArgumentException | if input or output is null or empty, or if an error occurs when running inference. |
| IllegalArgumentException | (EXPERIMENTAL, subject to change) if the inference is interrupted by setCancelled(true). |
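As an illustration, a minimal sketch assuming a model with a single float input of shape [1, 4] and a single float output of shape [1, 2]:

```java
// Direct, native-ordered buffers are the efficient way to feed and fetch floats.
FloatBuffer input =
    ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
input.put(new float[] {0.1f, 0.2f, 0.3f, 0.4f});
input.rewind();  // Caller must leave the buffer at the appropriate read position.
FloatBuffer output =
    ByteBuffer.allocateDirect(2 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
interpreter.run(input, output);
float result0 = output.get(0);
```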
public void runForMultipleInputsOutputs (Object[] inputs, Map<Integer, Object> outputs)

Runs model inference if the model takes multiple inputs, or returns multiple outputs.

Warning: The API is more efficient if Buffers (preferably direct, but not required) are used as the input/output data types. Please consider using Buffer to feed and fetch primitive data for better performance. The following concrete Buffer types are supported:

- ByteBuffer - compatible with any underlying primitive Tensor type.
- FloatBuffer - compatible with float Tensors.
- IntBuffer - compatible with int32 Tensors.
- LongBuffer - compatible with int64 Tensors.

Note that boolean types are only supported as arrays, not Buffers, or as scalar inputs.

Note: null values for individual elements of inputs and outputs are allowed only if the caller is using a Delegate that allows buffer handle interop, and such a buffer has been bound to the corresponding input or output Tensor(s).

Parameters
| inputs | an array of input data. The inputs should be in the same order as inputs of the model. Each input can be an array or multidimensional array, or a Buffer of primitive types including int, float, long, and byte. Buffer is the preferred way to pass large input data, whereas string types require using the (multi-dimensional) array input path. When a Buffer is used, its content should remain unchanged until model inference is done, and the caller must ensure that the Buffer is at the appropriate read position. |
| outputs | a map mapping output indices to multidimensional arrays of output data or Buffers of primitive types including int, float, long, and byte. It only needs to keep entries for the outputs to be used. When a Buffer is used, the caller must ensure that it is set to the appropriate write position. |

Throws
| IllegalArgumentException | if inputs or outputs is null or empty, or if an error occurs when running inference. |
public void setCancelled (boolean cancelled)

Advanced: Interrupts inference in the middle of a call to run(Object, Object).

A cancellation flag will be set to true when this function gets called. The interpreter will check the flag between op invocations, and if it's true, the interpreter will stop execution. The interpreter will remain in a cancelled state until explicitly "uncancelled" by setCancelled(false).

WARNING: This is an experimental API and subject to change.

Parameters
| cancelled | true to cancel inference in a best-effort way; false to resume. |

Throws
| IllegalStateException | if the interpreter is not initialized with the cancellable option, which is off by default. |
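A sketch of the intended flow, assuming the interpreter was created with the (experimental) Interpreter.Options.setCancellable(true) option and that `modelFile` exists:

```java
// Created with the cancellable option; otherwise setCancelled() throws
// IllegalStateException.
Interpreter interpreter =
    new Interpreter(modelFile, new Interpreter.Options().setCancellable(true));
// Elsewhere (e.g., on another thread) while run() is executing:
interpreter.setCancelled(true);   // Best-effort: checked between op invocations.
// ...
interpreter.setCancelled(false);  // "Uncancel" before the next run() call.
```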
public void setNumThreads (int numThreads)

This method was deprecated. Prefer using Interpreter.Options.setNumThreads(int) for controlling multi-threading; this method will be removed in a future release.

Sets the number of threads to be used for ops that support multi-threading.