Multi-linear interpolation on a rectilinear grid.
```python
tfp.substrates.jax.math.batch_interp_rectilinear_nd_grid(
    x,
    x_grid_points,
    y_ref,
    axis,
    fill_value='constant_extension',
    name=None
)
```
Given [a batch of] reference values, this function computes a multi-linear interpolant and evaluates it on [a batch of] new `x` values. This is a multi-dimensional generalization of bilinear interpolation.
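For reference, the textbook multi-linear interpolant on a single grid cell is shown below. The symbols are notation introduced here for illustration only: $nd$ is the number of interpolated dimensions, $x^{(k)}$ stands for the sorted 1-D grid along dimension $k$, $i_k$ is the index with $x^{(k)}_{i_k} \le x_k \le x^{(k)}_{i_k+1}$, and $t_k$, $w_k$ are per-dimension weights. This is the standard definition, not something taken from the TFP source:

$$
t_k = \frac{x_k - x^{(k)}_{i_k}}{x^{(k)}_{i_k+1} - x^{(k)}_{i_k}}, \qquad w_k(0) = 1 - t_k, \qquad w_k(1) = t_k
$$

$$
y(x) = \sum_{b \in \{0,1\}^{nd}} \Bigl(\prod_{k=0}^{nd-1} w_k(b_k)\Bigr)\, y_{\mathrm{ref}}[i_0 + b_0, \ldots, i_{nd-1} + b_{nd-1}]
$$

For $nd = 1$ this reduces to linear interpolation between two neighbors, and for $nd = 2$ to bilinear interpolation over the four corners of a cell.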
The interpolant is built from reference values indexed by `nd` dimensions of `y_ref`, starting at `axis`. Here `nd = x.shape[-1]` is the number of coordinates of each interpolation point.
The `x` grid is defined by 1-D points along each dimension. These points must be sorted, but may have unequal spacing.
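Because the spacing may be unequal, locating the cell that contains a query point takes a search rather than a single division by a fixed step. The following pure-Python sketch (not the TFP implementation) does this in one dimension with binary search from the standard `bisect` module. `interp_1d` is a hypothetical name, and clamping out-of-range points to the boundary values is our reading of `fill_value='constant_extension'`, stated here as an assumption.

```python
import bisect


def interp_1d(x, grid, values):
    """Linearly interpolate `values`, sampled at sorted `grid`, at point `x`.

    Sketch only: `grid` may be unevenly spaced but must be sorted ascending.
    Points outside [grid[0], grid[-1]] take the nearest boundary value
    (assumed to mimic fill_value='constant_extension').
    """
    if x <= grid[0]:
        return values[0]
    if x >= grid[-1]:
        return values[-1]
    # Binary search: i is the index of the first grid point strictly greater
    # than x, so grid[i - 1] <= x < grid[i]. No uniform spacing is assumed.
    i = bisect.bisect_right(grid, x)
    x_lo, x_hi = grid[i - 1], grid[i]
    t = (x - x_lo) / (x_hi - x_lo)  # Fractional position inside the cell.
    return (1 - t) * values[i - 1] + t * values[i]


grid = [0.0, 0.1, 0.5, 2.0]          # Sorted, unequal spacing.
values = [2 * g + 1 for g in grid]   # Samples of the line y = 2x + 1.
print(interp_1d(1.25, grid, values))  # 3.5
```

Binary search costs O(log C) per point for a grid of C points, which is why the only requirement on the grid is that it be sorted.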
For example, take the case of a 2-D scalar-valued function with no leading batch dimensions. In this case, `y_ref.shape = [C1, C2]` and `y_ref[i, j]` is the reference value corresponding to the grid point `[x_grid_points[0][i], x_grid_points[1][j]]`.
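To make this index convention concrete, here is a minimal pure-Python sketch of multi-linear interpolation on nested lists, where `y_ref[i][j]` plays the role of `y_ref[i, j]` above. The names `interp_nd`, `_cell`, `_lookup` and the test function `f` are hypothetical. The sketch handles one query point with no batching, and clamps points outside the grid to its boundary as an assumed stand-in for `fill_value='constant_extension'`. It illustrates the technique and is not the library's implementation.

```python
import bisect
import itertools


def _cell(x, grid):
    """Return (lower index i, fraction t) locating x in sorted `grid`.

    x is first clamped to [grid[0], grid[-1]] (assumed constant extension).
    """
    x = min(max(x, grid[0]), grid[-1])
    # bisect_right - 1 gives grid[i] <= x; cap i so grid[i + 1] exists.
    i = min(bisect.bisect_right(grid, x) - 1, len(grid) - 2)
    return i, (x - grid[i]) / (grid[i + 1] - grid[i])


def _lookup(table, index):
    """Index a nested list with a sequence, e.g. table[i][j] for (i, j)."""
    for i in index:
        table = table[i]
    return table


def interp_nd(point, grid_points, y_ref):
    """Multi-linear interpolation of nested-list `y_ref` at one point.

    y_ref[i0][i1]...[i_{nd-1}] is the value at grid point
    (grid_points[0][i0], ..., grid_points[nd - 1][i_{nd-1}]).
    """
    cells = [_cell(xk, gk) for xk, gk in zip(point, grid_points)]
    total = 0.0
    # Sum over the 2**nd corners of the cell that contains the point.
    for corner in itertools.product((0, 1), repeat=len(cells)):
        weight = 1.0
        index = []
        for (i, t), b in zip(cells, corner):
            weight *= t if b else 1.0 - t
            index.append(i + b)
        total += weight * _lookup(y_ref, index)
    return total


def f(x0, x1):
    # Bilinear in (x0, x1), so interpolation reproduces it up to rounding.
    return 1 + 2 * x0 + 3 * x1 + x0 * x1


x0_grid = [0.0, 1.0, 4.0]        # Sorted, unequal spacing.
x1_grid = [0.0, 0.5, 2.0, 3.0]
# y_ref[i][j] is the value at (x0_grid[i], x1_grid[j]), as in the text.
y_ref = [[f(a, b) for b in x1_grid] for a in x0_grid]
print(interp_nd((2.5, 1.25), (x0_grid, x1_grid), y_ref))  # 12.875
```

A bilinear test function is used because bilinear interpolation reproduces any function of the form a + b·x0 + c·x1 + d·x0·x1, which makes the result easy to check by hand: f(2.5, 1.25) = 12.875.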
In the general case, dimensions to the left of `axis` in `y_ref` are broadcast with the leading (batch) dimensions of `x` and of each `x_grid_points[k]`, for `k = 0, ..., nd - 1`.
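The shape bookkeeping can be sketched with plain tuples. The helpers `broadcast_shapes` and `interp_output_shape` below are hypothetical. They assume NumPy-style broadcasting of the batch dimensions, `x.shape = [..., D, nd]`, `x_grid_points[k].shape = [..., Ck]` and a non-negative `axis`; these conventions are inferred from this page, not read from the TFP source.

```python
def broadcast_shapes(*shapes):
    """NumPy-style broadcasting of shape tuples (right-aligned)."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"Incompatible batch dimensions: {dims}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


def interp_output_shape(x_shape, grid_shapes, y_ref_shape, axis):
    """Predicted shape [..., D, B1, ..., BM] of the interpolated output."""
    nd, d = x_shape[-1], x_shape[-2]
    # Batch dims: x's leading dims, y_ref's dims left of `axis`, and each
    # grid's leading dims all broadcast together.
    batch = broadcast_shapes(
        tuple(x_shape[:-2]),
        tuple(y_ref_shape[:axis]),
        *[tuple(g[:-1]) for g in grid_shapes],
    )
    # [B1, ..., BM]: the dims of y_ref after the nd grid dims.
    value_shape = tuple(y_ref_shape[axis + nd:])
    return batch + (d,) + value_shape


# 4 batches of 10 query points in 2-D, a vector-valued (size 3) table.
print(interp_output_shape((4, 10, 2), [(50,), (60,)], (1, 50, 60, 3), axis=1))
# (4, 10, 3)
```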
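Returns | |
---|---|
`y_interp` | Interpolation between members of `y_ref`, at points `x`. `Tensor` of the same dtype as `x`, with shape `[..., D, B1, ..., BM]`, where `[B1, ..., BM]` are the dimensions of `y_ref` that follow the `nd` grid dimensions. |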
Exceptions will be raised if shapes are statically determined to be wrong.
Raises | |
---|---|
`ValueError` | If `rank(x) < 2`. |
`ValueError` | If `axis` is not a scalar. |
`ValueError` | If `axis + nd > rank(y_ref)`. |
`ValueError` | If `x_grid_points[k].shape[-1] != y_ref.shape[axis + k]`. |
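As a rough cross-check of the table above, the documented static checks can be restated on plain shape tuples. `check_static_shapes` and `first_error` are hypothetical names, not TFP functions. The sketch assumes one grid per interpolated dimension and normalizes a negative `axis` against `rank(y_ref)`, which this page does not state.

```python
def check_static_shapes(x_shape, grid_shapes, y_ref_shape, axis):
    """Raise ValueError for each documented static shape error (sketch)."""
    if len(x_shape) < 2:
        raise ValueError("rank(x) must be at least 2, since x.shape = [..., D, nd].")
    if not isinstance(axis, int):
        raise ValueError("axis must be a scalar.")
    nd = x_shape[-1]
    rank = len(y_ref_shape)
    if axis < 0:
        axis += rank  # Assumed normalization of negative axes.
    if axis + nd > rank:
        raise ValueError(f"axis + nd = {axis + nd} exceeds rank(y_ref) = {rank}.")
    # Each grid's length must match the y_ref dimension it indexes.
    for k, g in enumerate(grid_shapes):
        if g[-1] != y_ref_shape[axis + k]:
            raise ValueError(
                f"x_grid_points[{k}].shape[-1] = {g[-1]} != "
                f"y_ref.shape[{axis + k}] = {y_ref_shape[axis + k]}.")


def first_error(*args, **kwargs):
    """Return the ValueError message, or None if the shapes pass."""
    try:
        check_static_shapes(*args, **kwargs)
    except ValueError as e:
        return str(e)
    return None


print(first_error((10, 2), [(100,), (99,)], (100, 100), axis=0))
# x_grid_points[1].shape[-1] = 99 != y_ref.shape[1] = 100.
```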
Examples
Interpolate a function of one variable.
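```python
import tensorflow as tf
import tensorflow_probability as tfp

x_grid = tf.linspace(0., 1., 20)**2  # Nonlinearly spaced, in [0, 1].
y_ref = tf.exp(x_grid)
tfp.math.batch_interp_rectilinear_nd_grid(
    # x.shape = [3, 1], with the trailing `1` for `1-D`.
    x=[[6.0], [0.5], [3.3]], x_grid_points=(x_grid,), y_ref=y_ref, axis=0)
# ==> approx [exp(1.0), exp(0.5), exp(1.0)]: 6.0 and 3.3 lie outside the
# grid, so fill_value='constant_extension' returns y_ref[-1] = exp(1.0).
```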
Interpolate a scalar function of two variables.
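```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x0_grid = tf.linspace(0., 2 * np.pi, num=100)
x1_grid = tf.linspace(0., 2 * np.pi, num=100)

# Build y_ref, with y_ref[i, j] = func(x0_grid[i], x1_grid[j]).
x0s, x1s = tf.meshgrid(x0_grid, x1_grid, indexing='ij')

def func(x0, x1):
  return tf.sin(x0) * tf.cos(x1)

y_ref = func(x0s, x1s)

# 10 query points in [0, pi)^2, so x.shape = [10, 2].
x = np.pi * tf.random.stateless_uniform(shape=(10, 2), seed=(0, 0))
tfp.math.batch_interp_rectilinear_nd_grid(
    x, x_grid_points=(x0_grid, x1_grid), y_ref=y_ref, axis=-2)
# ==> approx tf.sin(x[:, 0]) * tf.cos(x[:, 1]), with shape [10]
```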