tfp.experimental.auto_batching.TensorFlowBackend

Class TensorFlowBackend

Implements the TF backend ops for a PC (program counter) auto-batching VM.

__init__

__init__(
    safety_checks=True,
    while_parallel_iterations=10,
    while_maximum_iterations=None,
    basic_block_xla_device=None
)

Construct a new backend instance.

Args:

Properties

variable_class

Methods

any

any(
    t,
    name=None
)

assert_matching_dtype

assert_matching_dtype(
    expected_dtype,
    value,
    message=''
)

Asserts that the dtype of value matches expected_dtype.

Args:

  • expected_dtype: A numpy dtype
  • value: Tensor or convertible.
  • message: Optional diagnostic message.

Raises:

  • ValueError: If dtype does not match.
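
The check described above can be approximated in plain NumPy. This is a minimal sketch rather than the backend's implementation: `assert_matching_dtype_sketch` is a hypothetical name, and NumPy arrays stand in for Tensors.

```python
import numpy as np


def assert_matching_dtype_sketch(expected_dtype, value, message=''):
    """Raises ValueError unless `value` has dtype `expected_dtype`."""
    # Assumption: NumPy arrays stand in for Tensors in this sketch.
    expected = np.dtype(expected_dtype)
    actual = np.asarray(value).dtype
    if actual != expected:
        raise ValueError(
            'Mismatched dtype: expected {}, found {}. {}'.format(
                expected, actual, message))


# Passes silently because the dtypes agree.
assert_matching_dtype_sketch(np.float32, np.zeros(3, dtype=np.float32))
```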

batch_size

batch_size(
    value,
    name=None
)

Returns the first (batch) dimension of value.

broadcast_to_shape_of

broadcast_to_shape_of(
    val,
    target,
    name=None
)

Broadcasts val to the shape of target.

Attempts to match the dtype of broadcast_val to the dtype of target, if val is not a Tensor and target has a dtype.

Args:

  • val: The value to be broadcast. Must be broadcast-compatible with target.
  • target: Tensor whose shape we will broadcast val to match.
  • name: Optional name for the op.

Returns:

  • broadcast_val: A Tensor with shape matching val + target. Provided that val's dimension sizes are all smaller than or equal to target's, the returned value will have the shape of target.
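
To illustrate the shape and dtype behavior described above, here is a NumPy sketch. It is an approximation under stated assumptions, not the TF implementation: `broadcast_to_shape_of_sketch` is a hypothetical name, and NumPy arrays stand in for Tensors.

```python
import numpy as np


def broadcast_to_shape_of_sketch(val, target):
    """Broadcasts `val` to the broadcast shape of `val` and `target`."""
    target = np.asarray(target)
    if not isinstance(val, np.ndarray):
        # Non-array inputs adopt the dtype of `target`.
        val = np.asarray(val, dtype=target.dtype)
    # The result shape is that of `val + target`, computed without adding.
    result_shape = np.broadcast(val, target).shape
    return np.broadcast_to(val, result_shape)


# A Python scalar is converted to float32 and expanded to shape (3, 2).
out = broadcast_to_shape_of_sketch(2, np.zeros((3, 2), dtype=np.float32))
```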

cond

cond(
    pred,
    true_fn,
    false_fn,
    name=None
)

Implements a conditional operation for the backend.

Args:

  • pred: A boolean scalar Tensor indicating the condition.
  • true_fn: A callable accepting and returning nests of Tensors having the same structure as state, to be executed when pred is True.
  • false_fn: A callable accepting and returning nests of Tensors having the same structure as state, to be executed when pred is False.
  • name: Optional name for the op.

Returns:

  • state: Output state, matching nest structure of input argument state.

create_variable

create_variable(
    name,
    alloc,
    type_,
    max_stack_depth,
    batch_size
)

Returns an initialized Variable.

Args:

  • name: Name for the variable.
  • alloc: VariableAllocation for the variable.
  • type_: instructions.TensorType describing the sub-batch shape and dtype of the variable being created.
  • max_stack_depth: Scalar int Tensor, the maximum stack depth allocated.
  • batch_size: Scalar int Tensor, the number of parallel threads being executed.

Returns:

  • var: A new, initialized Variable object.

equal

equal(
    t1,
    t2,
    name=None
)

Implements equality comparison for TF backend.

fill

fill(
    value,
    size,
    dtype,
    shape,
    name=None
)

Fill a fresh batched Tensor of the given shape and dtype with value.

Args:

  • value: Scalar to fill with.
  • size: Scalar int Tensor specifying the number of VM threads.
  • dtype: tf.DType of the Tensor to be returned.
  • shape: Rank 1 int Tensor, the per-thread value shape.
  • name: Optional name for the op.

Returns:

  • result: Tensor of dtype values with shape [size, *shape]
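
The relationship between size, shape, and the result shape [size, *shape] can be sketched with NumPy. The function name `fill_sketch` is hypothetical, and a NumPy array stands in for the returned Tensor.

```python
import numpy as np


def fill_sketch(value, size, dtype, shape):
    """Returns an array of shape [size, *shape] filled with `value`."""
    # The leading dimension indexes VM threads; the rest is per-thread.
    return np.full([size] + list(shape), value, dtype=dtype)


# Four threads, each holding a 2x3 block of sevens.
batch = fill_sketch(7, size=4, dtype=np.int32, shape=[2, 3])
```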

full_mask

full_mask(
    size,
    name=None
)

Returns an all-True mask Tensor with shape [size].

merge_dtypes

merge_dtypes(
    dt1,
    dt2
)

Merges two dtypes, returning a compatible dtype.

In practice, the TF implementation asserts that the two dtypes are identical.

Args:

  • dt1: A numpy dtype, or None.
  • dt2: A numpy dtype, or None.

Returns:

  • dtype: The common numpy dtype.

Raises:

  • ValueError: If dt1 and dt2 are not equal and both are non-None.
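
The merging rule stated above (None defers to the other dtype, and two non-None dtypes must agree) might look like the following sketch. `merge_dtypes_sketch` is a hypothetical name, and the code only illustrates the documented contract.

```python
import numpy as np


def merge_dtypes_sketch(dt1, dt2):
    """Returns the common dtype of `dt1` and `dt2`, either may be None."""
    if dt1 is None:
        return dt2
    if dt2 is None:
        return dt1
    if np.dtype(dt1) != np.dtype(dt2):
        raise ValueError('Mismatched dtypes {} and {}'.format(dt1, dt2))
    return dt1


merged = merge_dtypes_sketch(None, np.dtype('float32'))
```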

merge_shapes

merge_shapes(
    s1,
    s2
)

Merges two shapes, returning a broadcasted shape.

Args:

  • s1: A list of Python int or None.
  • s2: A list of Python int or None.

Returns:

  • shape: A list of Python int or None.

Raises:

  • ValueError: If s1 and s2 are not broadcast compatible.
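
As an illustration of broadcasting two static shapes, here is a pure-Python sketch. The name `merge_shapes_sketch` is hypothetical, and the treatment of None (an unknown dimension) is an assumption of this sketch, not a statement about the backend's exact behavior.

```python
def merge_shapes_sketch(s1, s2):
    """Broadcasts two shapes given as lists of int or None (unknown)."""
    # Left-pad the shorter shape with 1s so both have the same rank.
    rank = max(len(s1), len(s2))
    p1 = [1] * (rank - len(s1)) + list(s1)
    p2 = [1] * (rank - len(s2)) + list(s2)
    merged = []
    for d1, d2 in zip(p1, p2):
        if d1 is None or d2 is None:
            # Assumption: the result is known only when the other side
            # is a known size greater than 1.
            known = d2 if d1 is None else d1
            merged.append(known if known is not None and known > 1 else None)
        elif d1 == d2 or d2 == 1:
            merged.append(d1)
        elif d1 == 1:
            merged.append(d2)
        else:
            raise ValueError(
                'Shapes {} and {} are not broadcast compatible'.format(s1, s2))
    return merged


shape = merge_shapes_sketch([3, 1], [4])
```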

not_equal

not_equal(
    t1,
    t2,
    name=None
)

Implements inequality comparison for TF backend.

prepare_for_cond

prepare_for_cond(state)

Backend hook for preparing Tensors for cond.

The TensorFlow backend uses this hook to apply tf.convert_to_tensor before entering the cond tree generated by virtual_machine._staged_apply. One could do this inside cond, but when this API element was defined there seemed to be a performance reason (for Eager mode) to do it once per cond tree rather than once per cond.

Args:

  • state: A state to be prepared for use in conditionals.

Returns:

  • state: The prepared state.

reduce_min

reduce_min(
    t,
    name=None
)

Implements reduce_min for TF backend.

run_on_dummies

run_on_dummies(
    primitive_callable,
    input_types
)

Runs the given primitive_callable with dummy input.

This is useful for examining the outputs for the purpose of type inference.

Args:

  • primitive_callable: A python callable.
  • input_types: List of instructions.Type, giving the type of each argument to the callable. Note that the contained TensorType objects must match the dimensions with which the primitive is to be invoked at runtime, even though type inference conventionally does not store the batch dimension in the TensorTypes.

Returns:

  • outputs: pattern of backend-specific objects whose types may be analyzed by the caller with type_of.
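
The idea of running a primitive on dummy inputs to learn its output type can be sketched as follows. Everything here is an assumption made for illustration: `run_on_dummies_sketch` is a hypothetical name, each input type is represented as a plain (shape, dtype) pair rather than an instructions.Type, the dummies are zero-filled NumPy arrays, and the callable is assumed to take its arguments positionally.

```python
import numpy as np


def run_on_dummies_sketch(primitive_callable, input_types):
    """Calls `primitive_callable` on zero-filled arrays of the given types."""
    # Hypothetical convention: each entry is a (shape, dtype) pair whose
    # shape includes the batch dimension used at runtime.
    dummies = [np.zeros(shape, dtype=dtype) for shape, dtype in input_types]
    return primitive_callable(*dummies)


# Only the shape and dtype of the result are of interest, not its values.
out = run_on_dummies_sketch(
    lambda x, y: x[:, np.newaxis] * y,
    [((5,), np.float32), ((5, 3), np.float32)])
```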

static_value

static_value(t)

Gets the eager/immediate value of t, or None if t is a Tensor.

switch_case

switch_case(
    branch_selector,
    branch_callables,
    name=None
)

Implements a switch (branch_selector) { case ... } construct.
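
In eager Python terms, the construct amounts to calling the branch chosen by an integer index. The sketch below assumes that branch_callables is a list of zero-argument callables indexed by branch_selector. `switch_case_sketch` is a hypothetical name, and rejecting out-of-range selectors is a choice of this sketch only.

```python
def switch_case_sketch(branch_selector, branch_callables):
    """Calls and returns the result of the selected branch."""
    # Assumption: branches are zero-argument callables in a list.
    if not 0 <= branch_selector < len(branch_callables):
        raise ValueError('No branch for selector {}'.format(branch_selector))
    return branch_callables[branch_selector]()


result = switch_case_sketch(1, [lambda: 'zero', lambda: 'one', lambda: 'two'])
```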

type_of

type_of(
    t,
    dtype_hint=None
)

Returns the instructions.Type of t.

Args:

  • t: tf.Tensor or a Python or numpy constant.
  • dtype_hint: dtype to prefer, if t is a constant.

Returns:

where

where(
    condition,
    x,
    y,
    name=None
)

Implements a where selector for the TF backend.

Attempts to match the dtypes of the value operands, if they are not yet both Tensors.

Args:

  • condition: A boolean Tensor, either a vector having length (x + y).shape[0] or matching the full shape of x + y.
  • x: Tensor of values to take when condition is True. Shape must match that of y.
  • y: Tensor of values to take when condition is False. Shape must match that of x.
  • name: Optional name for the op.

Returns:

  • masked: A broadcast-shaped Tensor where elements corresponding to True values of condition come from x, and others come from y.
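
The two accepted forms of condition (a per-batch-member vector, or a mask of the full shape) can be illustrated with NumPy. This is a sketch under the assumption that a vector condition applies along the leading batch dimension. `where_sketch` is a hypothetical name, and NumPy arrays stand in for Tensors.

```python
import numpy as np


def where_sketch(condition, x, y):
    """Selects from `x` where `condition` is True, otherwise from `y`."""
    x = np.asarray(x)
    y = np.asarray(y)
    condition = np.asarray(condition, dtype=bool)
    out_ndim = np.broadcast(x, y).ndim
    # Pad a batch-vector condition with trailing singleton dimensions so
    # that it broadcasts against the per-member value dimensions.
    while condition.ndim < out_ndim:
        condition = condition[..., np.newaxis]
    return np.where(condition, x, y)


# The first batch member takes its row from x, the second from y.
masked = where_sketch([True, False],
                      np.array([[1, 2], [3, 4]]),
                      np.zeros((2, 2), dtype=np.int64))
```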

while_loop

while_loop(
    cond,
    body,
    loop_vars,
    name=None
)

Implements while loops for TF backend.
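
For readers unfamiliar with functional while loops, the following pure-Python sketch shows the usual semantics, assuming the same calling convention as tf.while_loop (cond and body each receive the loop variables as positional arguments). `while_loop_sketch` is a hypothetical name, and its `maximum_iterations` parameter is an illustrative stand-in for the constructor's while_maximum_iterations option, not part of this method's signature.

```python
def while_loop_sketch(cond, body, loop_vars, maximum_iterations=None):
    """Applies `body` to `loop_vars` repeatedly while `cond` holds."""
    loop_vars = tuple(loop_vars)
    iterations = 0
    while cond(*loop_vars):
        # Optional cap on the number of body executions.
        if maximum_iterations is not None and iterations >= maximum_iterations:
            break
        loop_vars = tuple(body(*loop_vars))
        iterations += 1
    return loop_vars


# Sum the integers 0 through 4 by threading (i, total) through the loop.
final = while_loop_sketch(
    lambda i, total: i < 5,
    lambda i, total: (i + 1, total + i),
    (0, 0))
```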

wrap_straightline_callable

wrap_straightline_callable(f)

This method exists solely to be stubbed, i.e. for defun + XLA compile.