tfp.experimental.auto_batching.NumpyBackend

Implements the Numpy backend ops for a program-counter (PC) auto-batching VM.

variable_class

Methods

any

assert_matching_dtype

Asserts that the dtype of val matches expected_dtype.

Args
expected_dtype A numpy dtype
val An object convertible to np.array
message Optional diagnostic message.

Raises
ValueError If dtype does not match.
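
The check described above can be illustrated with plain NumPy. This is a minimal sketch and not the library's code: the helper name `assert_matching_dtype_sketch` and the wording of the error message are assumptions made for illustration.

```python
import numpy as np

def assert_matching_dtype_sketch(expected_dtype, val, message=""):
    """Hypothetical helper: raise ValueError unless val's dtype equals expected_dtype."""
    actual = np.asarray(val).dtype
    expected = np.dtype(expected_dtype)
    if actual != expected:
        raise ValueError(
            "Mismatched dtype: expected {}, got {}. {}".format(expected, actual, message))

# A matching dtype passes silently.
assert_matching_dtype_sketch(np.float32, np.zeros(3, dtype=np.float32))
```

A mismatch, such as passing a float64 array where `np.int32` is expected, raises `ValueError` in this sketch.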

batch_size

Returns the first (batch) dimension of val.

broadcast_to_shape_of

Broadcasts val to the shape of target.

Args
val Python or Numpy array to be broadcast. Must be np.array compatible and broadcast-compatible with target.
target Python or Numpy array whose shape we broadcast val to match.
name Optional name for the op.

Returns
broadcast_val An np.ndarray whose shape is the broadcast of the shapes of val and target, i.e. the shape that val + target would have. Provided that val's dimension sizes are all smaller than or equal to target's, the returned value will have the shape of target.
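
The broadcasting behaviour described above can be sketched with standard NumPy calls. The helper name `broadcast_to_shape_of_sketch` is hypothetical, and the sketch only mirrors the documented shape rule; it is not taken from the backend's source.

```python
import numpy as np

def broadcast_to_shape_of_sketch(val, target):
    """Hypothetical helper: give val the shape that val + target would have."""
    val = np.asarray(val)
    target = np.asarray(target)
    # np.broadcast computes the joint broadcast shape without adding the values.
    result_shape = np.broadcast(val, target).shape
    # np.broadcast_to returns a read-only view; copy it if you need to write to it.
    return np.broadcast_to(val, result_shape)

val = np.array([1.0, 2.0, 3.0])   # shape [3]
target = np.zeros((2, 3))         # shape [2, 3]
out = broadcast_to_shape_of_sketch(val, target)
```

Here every dimension of `val` is no larger than the matching dimension of `target`, so `out` takes the shape of `target`.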

cond

Implements a conditional operation for the backend.

Args
pred A Python or Numpy bool scalar indicating the condition.
true_fn A callable accepting and returning nests of np.ndarrays with the same structure as state, to be executed when pred is True.
false_fn A callable accepting and returning nests of np.ndarrays with the same structure as state, to be executed when pred is False.
name Optional name for the op.

Returns
state Output state, matching nest structure of input argument state.
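
In an eager NumPy setting, a conditional of this kind reduces to an ordinary Python branch. The sketch below assumes that the branch callables are zero-argument closures over the state, which the text above does not confirm; `cond_sketch` is a hypothetical name.

```python
import numpy as np

def cond_sketch(pred, true_fn, false_fn):
    """Hypothetical helper: run exactly one branch, chosen by the truth value of pred."""
    if pred:
        return true_fn()
    return false_fn()

state = np.array([1, 2, 3])
# Assumption: each branch closes over `state` rather than receiving it as an argument.
result = cond_sketch(np.bool_(True), lambda: state + 1, lambda: state - 1)
```

Only the selected branch runs, so side effects in the other branch never occur.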

create_variable

Returns an initialized Variable.

Args
name Name for the variable.
alloc VariableAllocation for the variable.
type_ instructions.TensorType describing the sub-batch shape and dtype of the variable being created.
max_stack_depth Python int, the maximum stack depth to enforce.
batch_size Python int, the number of parallel threads being executed.

Returns
var A new, initialized Variable object.

equal

Implements equality comparison for Numpy backend.

fill

Fills a fresh batched Tensor of the given shape and dtype with value.

Args
value Scalar to fill with.
size Scalar int Tensor specifying the number of VM threads.
dtype tf.DType of the values to be returned.
shape Rank 1 int Tensor, the per-thread value shape.
name Optional name for the op.

Returns
result Tensor of dtype dtype, filled with value, with shape [size, *shape].
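
The result shape [size, *shape] can be reproduced with `np.full`. This is a sketch of the documented behaviour under the assumption that plain NumPy dtypes and Python lists are acceptable inputs; `fill_sketch` is a hypothetical name.

```python
import numpy as np

def fill_sketch(value, size, dtype, shape):
    """Hypothetical helper: batch of `size` threads, each holding `shape` copies of value."""
    return np.full([size] + list(shape), value, dtype=dtype)

# Four VM threads, each with a [2, 3] value filled with 7.
batch = fill_sketch(7, size=4, dtype=np.int32, shape=[2, 3])
```

The leading dimension is the batch (thread) dimension, followed by the per-thread value shape.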

full_mask

Returns an all-True mask np.ndarray with shape [size].

merge_dtypes

Merges two dtypes, returning a compatible dtype.

Args
dt1 A numpy dtype, or None.
dt2 A numpy dtype, or None.

Returns
dtype The more precise numpy dtype (e.g. prefers int64 over int32).
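
One way to obtain "the more precise dtype" is NumPy's own promotion rule. The sketch below uses `np.promote_types` and treats None as "no constraint"; both choices are assumptions about the intended semantics, and `merge_dtypes_sketch` is a hypothetical name.

```python
import numpy as np

def merge_dtypes_sketch(dt1, dt2):
    """Hypothetical helper: merge two optional dtypes using NumPy promotion."""
    # None must be handled explicitly: np.dtype(None) would silently give float64.
    if dt1 is None and dt2 is None:
        return None
    if dt1 is None:
        return np.dtype(dt2)
    if dt2 is None:
        return np.dtype(dt1)
    return np.promote_types(dt1, dt2)

merged = merge_dtypes_sketch(np.int32, np.int64)  # dtype('int64')
```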

merge_shapes

Merges two shapes, returning a broadcasted shape.

Args
s1 A list of Python int or None.
s2 A list of Python int or None.

Returns
shape A list of Python int or None.

Raises
ValueError If s1 and s2 are not broadcast compatible.
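
Shape merging by broadcasting can be written in plain Python. The sketch follows NumPy's right-aligned broadcasting rule and assumes that a None argument means "shape not yet known", so the other shape is returned unchanged; that reading of None, and the name `merge_shapes_sketch`, are assumptions.

```python
def merge_shapes_sketch(s1, s2):
    """Hypothetical helper: broadcast two shapes given as lists of ints (or None)."""
    if s1 is None:
        return s2
    if s2 is None:
        return s1
    # Right-align the shapes by left-padding the shorter one with 1s.
    rank = max(len(s1), len(s2))
    p1 = [1] * (rank - len(s1)) + list(s1)
    p2 = [1] * (rank - len(s2)) + list(s2)
    merged = []
    for d1, d2 in zip(p1, p2):
        if d1 == d2 or d2 == 1:
            merged.append(d1)
        elif d1 == 1:
            merged.append(d2)
        else:
            raise ValueError(
                "Shapes {} and {} are not broadcast compatible".format(s1, s2))
    return merged

merged_shape = merge_shapes_sketch([3, 1], [4])  # [3, 4]
```

Incompatible sizes, such as `[2]` against `[3]`, raise `ValueError` in this sketch.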

not_equal

Implements inequality comparison for Numpy backend.

prepare_for_cond

Backend hook for preparing Tensors for cond.

Does nothing in the numpy backend (needed by the TensorFlow backend).

Args
state A state to be prepared for use in conditionals.

Returns
state The prepared state.

reduce_min

Implements reduce_min for Numpy backend.

run_on_dummies

Runs the given primitive_callable with dummy input.

This is useful for examining the outputs for the purpose of type inference.

Args
primitive_callable A python callable.
input_types List of instructions.Type, giving the type of each argument to the callable. Note that the contained TensorType objects must match the dimensions with which the primitive is to be invoked at runtime, even though type inference conventionally does not store the batch dimension in the TensorTypes.

Returns
outputs pattern of backend-specific objects whose types may be analyzed by the caller with type_of.
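
The idea of probing a primitive with dummy inputs can be sketched without the library's type classes. In the sketch below, each input type is a hypothetical (dtype, shape) pair standing in for instructions.TensorType, and the dummies are zero-filled arrays; both are assumptions for illustration. As the note above requires, the shapes include the batch dimension.

```python
import numpy as np

def run_on_dummies_sketch(primitive_callable, input_specs):
    """Hypothetical helper: call the primitive on zero-filled arrays built from specs."""
    dummies = [np.zeros(shape, dtype=dtype) for dtype, shape in input_specs]
    return primitive_callable(*dummies)

# A primitive that scales each row of y by the matching entry of x.
def scale_rows(x, y):
    return x[:, None] * y

# Batch size 4: x has shape [4], y has shape [4, 3].
out = run_on_dummies_sketch(scale_rows, [(np.float32, [4]), (np.float32, [4, 3])])
```

Only the dtype and shape of `out` matter to a caller doing type inference; the values themselves are discarded.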

static_value

Gets the eager/immediate value of t.

switch_case

Implements a switch (branch_selector) { case ... } construct.

type_of

Returns the instructions.Type of t.

Args
t np.ndarray or a Python constant.
dtype_hint dtype to prefer, if t is a constant.

Returns
vm_type instructions.TensorType describing t

where

Implements a where selector for the Numpy backend.

Extends tf.where to support broadcasting of the on-false value, y.

Args
condition A bool np.ndarray, either a vector having length y.shape[0] or matching the full shape of y.
x np.ndarray of values to take when condition is True.
y np.ndarray of values to take when condition is False. May be smaller than x, as long as it is broadcast-compatible.
name Optional name for the op.

Returns
masked A np.ndarray where indices corresponding to True values in condition come from the corresponding value in x, and others come from y.
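
The selection rule above, with a per-thread vector condition and a smaller y, can be sketched in plain NumPy. The sketch assumes y is broadcast to x's shape using NumPy's usual trailing-axis rule and that a rank-1 condition selects whole batch members; `where_sketch` is a hypothetical name.

```python
import numpy as np

def where_sketch(condition, x, y):
    """Hypothetical helper: np.where with a batch-vector condition and broadcast y."""
    condition = np.asarray(condition)
    x = np.asarray(x)
    # Assumption: y is broadcast up to x's full shape.
    y = np.broadcast_to(np.asarray(y), x.shape)
    if condition.ndim == 1 and x.ndim > 1:
        # Reshape [n] to [n, 1, ..., 1] so each entry selects a whole batch member.
        condition = condition.reshape([-1] + [1] * (x.ndim - 1))
    return np.where(condition, x, y)

x = np.array([[1, 2, 3], [4, 5, 6]])
masked = where_sketch([True, False], x, [0, 0, 0])
```

The first batch member comes from `x` and the second from the broadcast `y`, so `masked` is `[[1, 2, 3], [0, 0, 0]]`.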

while_loop

Implements while loops for Numpy backend.

wrap_straightline_callable

This method exists solely to be stubbed, e.g. for defun or XLA compilation.