Returns (batched) matmul of a SparseTensor (or Tensor) with a Tensor.
tfp.experimental.substrates.jax.math.linalg.sparse_or_dense_matmul(
sparse_or_dense_a,
dense_b,
validate_args=False,
name=None,
**kwargs
)
Args:
`sparse_or_dense_a`: `SparseTensor` or `Tensor` representing a (batch of) matrices.
`dense_b`: `Tensor` representing a (batch of) matrices, with the same batch shape as `sparse_or_dense_a`. The shape must be compatible with the shape of `sparse_or_dense_a` and kwargs.
`validate_args`: When `True`, additional assertions might be embedded in the graph. Default value: `False` (i.e., no graph assertions are added).
`name`: Python `str` prefixed to ops created by this function. Default value: 'sparse_or_dense_matmul'.
`**kwargs`: Keyword arguments to `tf.sparse_tensor_dense_matmul` or `tf.matmul`.
Returns:
`product`: A dense (batch of) matrix-shaped `Tensor` of the same batch shape and dtype as `sparse_or_dense_a` and `dense_b`. If `sparse_or_dense_a` or `dense_b` is adjointed through `kwargs`, then the shape is adjusted accordingly.
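
The shape rule for `product` can be sketched in plain Python. The helper `expected_product_shape` below is hypothetical and not part of the library. It assumes the adjoint is requested through the `adjoint_a` / `adjoint_b` keyword arguments accepted by `tf.matmul`, and that both operands have identical batch shapes, as the `dense_b` description requires. It does not model broadcasting or any other kwargs.

```python
def expected_product_shape(a_shape, b_shape, adjoint_a=False, adjoint_b=False):
    """Hypothetical helper: shape of the dense product described above.

    Assumes identical batch shapes for both operands; broadcasting is
    deliberately not modelled here.
    """
    a_shape, b_shape = tuple(a_shape), tuple(b_shape)
    if len(a_shape) < 2 or len(b_shape) < 2:
        raise ValueError("Both operands must be (batches of) matrices.")

    # Split each shape into leading batch dimensions and the trailing matrix.
    a_batch, (a_rows, a_cols) = a_shape[:-2], a_shape[-2:]
    b_batch, (b_rows, b_cols) = b_shape[:-2], b_shape[-2:]
    if a_batch != b_batch:
        raise ValueError(f"Batch shapes differ: {a_batch} vs {b_batch}.")

    # Taking the adjoint swaps the two matrix dimensions of that operand.
    if adjoint_a:
        a_rows, a_cols = a_cols, a_rows
    if adjoint_b:
        b_rows, b_cols = b_cols, b_rows

    if a_cols != b_rows:
        raise ValueError(f"Inner dimensions differ: {a_cols} vs {b_rows}.")
    return a_batch + (a_rows, b_cols)


# A batch of five 3x4 matrices times a batch of five 4x2 matrices.
print(expected_product_shape((5, 3, 4), (5, 4, 2)))
# The same product when the first operand is stored as 4x3 and adjointed.
print(expected_product_shape((5, 4, 3), (5, 4, 2), adjoint_a=True))
```

Both calls describe a product with batch shape `(5,)` and matrix shape `(3, 2)`. Checking shapes this way before the real call can make a mismatch easier to diagnose than an error raised from inside the matmul op.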