Batches all the input tensors to the computation done by the function.
```python
tf.raw_ops.BatchFunction(
    in_tensors,
    captured_tensors,
    f,
    num_batch_threads,
    max_batch_size,
    batch_timeout_micros,
    Tout,
    max_enqueued_batches=10,
    allowed_batch_sizes=[],
    container='',
    shared_name='',
    batching_queue='',
    enable_large_batch_splitting=False,
    name=None
)
```
So, for example, in the following code:

```python
import tensorflow as tf
from tensorflow.python.ops import gen_batch_ops

# This input will be batched.
a = tf.placeholder(tf.float32)

# This input will be captured.
y = tf.placeholder_with_default(1.0, shape=[])

@tf.Defun(tf.float32)
def computation(a):
  return tf.matmul(a, a) + y

b = gen_batch_ops.batch_function(
    f=computation,
    in_tensors=[a],
    captured_tensors=computation.captured_inputs,
    Tout=[o.type for o in computation.definition.signature.output_arg],
    num_batch_threads=1,
    max_batch_size=10,
    batch_timeout_micros=100000,  # 100ms
    allowed_batch_sizes=[3, 10],
    batching_queue="")
```
If more than one session.run call is simultaneously trying to compute `b`, the values of `a` will be gathered, non-deterministically concatenated along the first axis, and only one thread will run the computation.
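To make that gathering concrete, here is a minimal self-contained sketch (a sketch only: the element-wise `square` function, the thread count, and the shapes are illustrative choices, not part of the op's API). Three threads issue session.run calls at roughly the same time, so their single-row inputs can be merged into one batch of size 3, which is one of the `allowed_batch_sizes`:

```python
import threading

import tensorflow as tf
from tensorflow.python.ops import gen_batch_ops

x = tf.placeholder(tf.float32, shape=[1])  # each call contributes one row

@tf.Defun(tf.float32)
def square(t):
  # Element-wise, so it stays valid for whatever batch size is gathered.
  return t * t

batched = gen_batch_ops.batch_function(
    f=square,
    in_tensors=[x],
    captured_tensors=square.captured_inputs,
    Tout=[o.type for o in square.definition.signature.output_arg],
    num_batch_threads=1,
    max_batch_size=10,
    batch_timeout_micros=100000,  # wait up to 100ms to fill a batch
    allowed_batch_sizes=[3, 10],
    batching_queue="")

with tf.Session() as sess:
  def worker(v):
    # Blocks until the shared batch is computed, then returns this
    # call's slice of the result.
    print(sess.run(batched, feed_dict={x: [v]}))

  threads = [threading.Thread(target=worker, args=(float(i),))
             for i in range(3)]
  for t in threads:
    t.start()
  for t in threads:
    t.join()
```

Each worker receives only its own row of the result back; the concatenation and splitting are handled by the op.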
Assumes that all arguments of the function are Tensors which will be batched along their first dimension.

Arguments that are captured are not batched. The session.run call which does the concatenation will use the values of the captured tensors that are available to it. Therefore, typical uses of captured tensors should involve values which remain unchanged across session.run calls; inference is a good example of this.
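The weights of an inference model fit this pattern. In the hedged sketch below (the names `w`, `predict`, `features`, and the shapes are hypothetical, not part of this API), `w` is captured by the function, so every batched call shares the same weights while only `features` is concatenated along axis 0:

```python
import tensorflow as tf
from tensorflow.python.ops import gen_batch_ops

features = tf.placeholder(tf.float32, shape=[1, 4])  # one example per call
w = tf.get_variable("w", shape=[4, 4], dtype=tf.float32)

@tf.Defun(tf.float32)
def predict(batch):
  # `batch` is the concatenation of many callers' `features`;
  # `w` is captured once and is not batched.
  return tf.matmul(batch, w)

logits = gen_batch_ops.batch_function(
    f=predict,
    in_tensors=[features],
    captured_tensors=predict.captured_inputs,
    Tout=[o.type for o in predict.definition.signature.output_arg],
    num_batch_threads=1,
    max_batch_size=8,
    batch_timeout_micros=10000,
    batching_queue="")
```

Because `w` changes rarely (or never) during serving, every caller gets results computed with the same weights, which is exactly the "unchanged across session.run calls" property described above.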
SparseTensor is not supported. The return value of the decorated function must be a Tensor or a list/tuple of Tensors.
| Returns |
|---|
| A list of `Tensor` objects of type `Tout`. |