Multi-linear interpolation on a regular (constant spacing) grid.
```python
tfp.math.batch_interp_regular_nd_grid(
    x, x_ref_min, x_ref_max, y_ref, axis, fill_value='constant_extension',
    name=None
)
```
Given [a batch of] reference values, this function computes a multi-linear interpolant and evaluates it on [a batch of] new `x` values.

The interpolant is built from reference values indexed by `nd` dimensions of `y_ref`, starting at `axis`.
For example, take the case of a 2-D scalar-valued function and no leading batch dimensions. In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]` is the reference value corresponding to the grid point

```
[x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
 x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
```
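The grid-point formula above generalizes to any `nd` by applying it once per dimension. The following is a small sketch of that mapping. The helper `grid_point` is hypothetical and takes the grid sizes `C1, ..., Cnd` explicitly as `shape`.

```python
from typing import List, Sequence


def grid_point(
    index: Sequence[int],
    x_ref_min: Sequence[float],
    x_ref_max: Sequence[float],
    shape: Sequence[int],
) -> List[float]:
    """Coordinates of the grid point that y_ref[index] corresponds to.

    Hypothetical helper: dimension k has shape[k] evenly spaced points
    running from x_ref_min[k] to x_ref_max[k] inclusive.
    """
    return [
        lo + i * (hi - lo) / (c - 1)
        for i, lo, hi, c in zip(index, x_ref_min, x_ref_max, shape)
    ]
```

For instance, with `x_ref_min = [0., 0.]`, `x_ref_max = [1., 2.]` and `shape = [3, 5]`, index `[1, 2]` maps to `[0.5, 1.0]`, and the last index `[2, 4]` maps to `x_ref_max`.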
In the general case, dimensions to the left of `axis` in `y_ref` are broadcast with leading dimensions in `x`, `x_ref_min`, `x_ref_max`.
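The shape bookkeeping implied by this broadcasting, together with the documented return shape `[..., D, B1, ..., BM]`, can be sketched as below. This assumes NumPy-style broadcasting rules and that a negative `axis` counts from the end of `y_ref`. Both are assumptions, not statements about TFP's internals, and the helpers `broadcast_shapes` and `output_shape` are hypothetical.

```python
from typing import List, Sequence


def broadcast_shapes(*shapes: Sequence[int]) -> List[int]:
    """NumPy-style broadcasting of several shapes, aligned on the right."""
    rank = max((len(s) for s in shapes), default=0)
    result: List[int] = []
    for k in range(rank):  # k counts positions from the right, 0-based
        dims = {s[len(s) - 1 - k] for s in shapes if k < len(s)}
        dims.discard(1)  # size-1 dims stretch to match anything
        if len(dims) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(dims.pop() if dims else 1)
    return result[::-1]


def output_shape(
    x_shape: Sequence[int],
    x_ref_min_shape: Sequence[int],
    x_ref_max_shape: Sequence[int],
    y_ref_shape: Sequence[int],
    axis: int,
) -> List[int]:
    """Hypothetical shape of y_interp: [...batch, D, B1, ..., BM]."""
    d, nd = x_shape[-2], x_shape[-1]
    if axis < 0:
        axis += len(y_ref_shape)  # assumption: negative axis counts from the end
    batch = broadcast_shapes(
        x_shape[:-2], x_ref_min_shape[:-1], x_ref_max_shape[:-1], y_ref_shape[:axis]
    )
    return batch + [d] + list(y_ref_shape[axis + nd:])
```

With the shapes from the 1-D example (`x` of shape `[3, 1]`, `y_ref` of shape `[20]`, `axis=0`) this gives `[3]`. A batch of 4 matrix-valued 2-D tables, `y_ref.shape = [4, 50, 60, 3, 3]` with `axis=1` and `x.shape = [4, 10, 2]`, gives `[4, 10, 3, 3]`.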
| Args | |
|---|---|
| `x` | Numeric `Tensor`. The x-coordinates of the interpolated output values for each batch. Shape `[..., D, nd]`, designating [a batch of] `D` coordinates in `nd` space. `D` must be `>= 1` and is not a batch dim. |
| `x_ref_min` | `Tensor` of same `dtype` as `x`. The minimum values of the (implicitly defined) reference `x_ref`. Shape `[..., nd]`. |
| `x_ref_max` | `Tensor` of same `dtype` as `x`. The maximum values of the (implicitly defined) reference `x_ref`. Shape `[..., nd]`. |
| `y_ref` | `Tensor` of same `dtype` as `x`. The reference output values. Shape `[..., C1, ..., Cnd, B1, ..., BM]`, designating [a batch of] reference values indexed by `nd` dimensions, of a shape `[B1, ..., BM]` valued function (for `M >= 0`). |
| `axis` | Scalar integer `Tensor`. Dimensions `[axis, axis + nd)` of `y_ref` index the interpolation table. For example, 3-D interpolation of a scalar-valued function requires `axis=-3`, and 3-D interpolation of a matrix-valued function requires `axis=-5`. |
| `fill_value` | Determines what values the output should take for `x` values that are below `x_ref_min` or above `x_ref_max`. Either a scalar `Tensor` or `'constant_extension'` (extend the interpolant as a constant function). Default value: `'constant_extension'`. |
| `name` | A name to prepend to created ops. Default value: `'batch_interp_regular_nd_grid'`. |
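The `axis` examples follow from the layout of `y_ref`: the `nd` table dimensions must sit immediately before the `M` value dimensions `B1, ..., BM`, so counting from the end gives `axis = -(nd + M)`. The sketch below illustrates this and splits a `y_ref` shape into its batch, table, and value parts. The helper names are hypothetical, and normalizing a negative `axis` by adding `rank(y_ref)` is an assumption.

```python
from typing import List, Sequence, Tuple


def interp_axis(nd: int, m: int) -> int:
    """Negative axis placing nd table dims just before m value dims."""
    return -(nd + m)


def split_y_ref_shape(
    shape: Sequence[int], axis: int, nd: int
) -> Tuple[List[int], List[int], List[int]]:
    """Split a y_ref shape into (batch dims, [C1..Cnd], [B1..BM])."""
    if axis < 0:
        axis += len(shape)  # assumption: negative axis counts from the end
    return list(shape[:axis]), list(shape[axis:axis + nd]), list(shape[axis + nd:])
```

For a 3-D scalar-valued table (`M = 0`), `interp_axis(3, 0)` is `-3`. For a 3-D table of 3x3 matrices (`M = 2`), it is `-5`, matching the `axis` description above.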
| Returns | |
|---|---|
| `y_interp` | Interpolation between members of `y_ref`, at points `x`. `Tensor` of same `dtype` as `x`, and shape `[..., D, B1, ..., BM]`. |
| Raises | |
|---|---|
| `ValueError` | If `rank(x) < 2` is determined statically. |
| `ValueError` | If `axis` is determined statically not to be a scalar. |
| `ValueError` | If `axis + nd > rank(y_ref)` is determined statically. |
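The three static checks can be mirrored on plain Python shapes, as in the sketch below. It is illustrative only. The function `check_static_shapes` is hypothetical, it takes `nd` from the last dimension of `x` as the `Args` table describes, and it assumes a negative `axis` is first normalized by adding `rank(y_ref)`.

```python
from typing import Sequence


def check_static_shapes(
    x_shape: Sequence[int], y_ref_shape: Sequence[int], axis
) -> None:
    """Hypothetical mirror of the documented static ValueError checks."""
    if len(x_shape) < 2:
        raise ValueError("rank(x) must be >= 2; expected shape [..., D, nd].")
    if not isinstance(axis, int) or isinstance(axis, bool):
        raise ValueError("axis must be a scalar integer.")
    nd = x_shape[-1]
    rank = len(y_ref_shape)
    if axis < 0:
        axis += rank  # assumption: negative axis counts from the end
    if axis + nd > rank:
        raise ValueError(f"axis + nd = {axis + nd} exceeds rank(y_ref) = {rank}.")
```

For example, 3-D query points (`nd = 3`) against a 2-D table of shape `[100, 100]` with `axis=-2` fail the last check, because the table has only two dimensions to index.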
Examples
Interpolate a function of one variable.
```python
import tensorflow as tf
import tensorflow_probability as tfp

y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))
tfp.math.batch_interp_regular_nd_grid(
    # x.shape = [3, 1], x_ref_min/max.shape = [1]. Trailing `1` for `1-D`.
    x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
    axis=0)
# ==> approx [exp(6.0), exp(0.5), exp(3.3)]
```
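The result is only approximate because the 20-point grid has spacing `10 / 19`, and `exp` is convex. A linear interpolant of a convex function never falls below it, so every estimate is a slight overestimate. The stdlib-only sketch below reproduces the same 1-D computation so you can inspect the error. The helper `interp_1d_regular` is hypothetical and clamps out-of-range inputs, mirroring `'constant_extension'`.

```python
import math


def interp_1d_regular(x, x_min, x_max, y_ref):
    """Linear interpolation on an evenly spaced 1-D grid (hypothetical helper)."""
    n = len(y_ref)
    spacing = (x_max - x_min) / (n - 1)
    # Fractional grid index, clamped so out-of-range x extends as a constant.
    t = min(max((x - x_min) / spacing, 0.0), n - 1.0)
    lo = min(int(t), n - 2)
    frac = t - lo
    return (1.0 - frac) * y_ref[lo] + frac * y_ref[lo + 1]


# Same grid as the example above: 20 evenly spaced points on [0, 10].
y_ref = [math.exp(10.0 * i / 19) for i in range(20)]
for x in (6.0, 0.5, 3.3):
    approx = interp_1d_regular(x, 0.0, 10.0, y_ref)
    print(x, approx, approx / math.exp(x))  # ratio slightly above 1
```

For these three points the relative overestimate ranges from under 1% (at `0.5`) to about 3.5% (at `6.0`). A denser grid shrinks it.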
Interpolate a scalar function of two variables.
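```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x_ref_min = [0., 0.]
x_ref_max = [2 * np.pi, 2 * np.pi]

# Build y_ref on a 100 x 100 grid; indexing='ij' makes y_ref[i, j] match
# the i-th x0 value and the j-th x1 value.
x0s, x1s = tf.meshgrid(
    tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
    tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
    indexing='ij')

def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

y_ref = func(x0s, x1s)

# 10 query points in [0, pi)^2, shape [10, 2].
x = np.pi * tf.random.uniform(shape=(10, 2))
tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
# ==> approx tf.sin(x[:, 0]) * tf.cos(x[:, 1])
```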