Tensor contraction over specified indices and outer product.
tf.einsum(equation, *inputs, **kwargs)
This function returns a tensor whose elements are defined by the equation argument, which is written in a shorthand form inspired by the Einstein summation convention. As an example, consider multiplying two matrices A and B to form a matrix C. The elements of C are given by:
C[i,k] = sum_j A[i,j] * B[j,k]
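A direct transcription of this formula into plain Python loops may make the summation over j concrete. This is a minimal sketch over nested lists; matmul_by_definition is a hypothetical name, not part of TensorFlow.

```python
def matmul_by_definition(A, B):
    """Hypothetical helper: C[i,k] = sum_j A[i,j] * B[j,k] on nested lists."""
    n, m, p = len(A), len(B), len(B[0])
    C = [[0] * p for _ in range(n)]
    for i in range(n):
        for k in range(p):
            # j is summed over because it does not index the output C
            C[i][k] = sum(A[i][j] * B[j][k] for j in range(m))
    return C


print(matmul_by_definition([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```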
The corresponding equation is ij,jk->ik. In general, the equation is obtained from the more familiar element-wise equation by:
1. removing variable names, brackets, and commas,
2. replacing "*" with ",",
3. dropping summation signs, and
4. moving the output to the right, and replacing "=" with "->".
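The four steps are mechanical enough to sketch as a string transformation. The helper to_einsum_equation below is hypothetical and assumes the simple notation used here: one-letter index names inside brackets and a sum_<index> prefix for each summation.

```python
import re


def to_einsum_equation(elementwise: str) -> str:
    """Hypothetical helper applying the four rewriting steps to a simple
    element-wise equation such as "C[i,k] = sum_j A[i,j] * B[j,k]"."""
    lhs, rhs = elementwise.split("=", 1)

    # Step 3: drop summation signs such as "sum_j".
    rhs = re.sub(r"sum_\w+", "", rhs)

    def subscripts(term: str) -> str:
        # Step 1: remove the variable name, brackets, and commas.
        match = re.search(r"\[(.*?)\]", term)
        return match.group(1).replace(",", "").replace(" ", "") if match else ""

    # Step 2: factors separated by "*" become inputs separated by ",".
    inputs = ",".join(subscripts(factor) for factor in rhs.split("*"))

    # Step 4: the output moves to the right, after "->".
    return inputs + "->" + subscripts(lhs)


print(to_einsum_equation("C[i,k] = sum_j A[i,j] * B[j,k]"))  # ij,jk->ik
```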
Many common operations can be expressed in this way. For example:
# Matrix multiplication
>>> einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j,k]
# Dot product
>>> einsum('i,i->', u, v)  # output = sum_i u[i] * v[i]
# Outer product
>>> einsum('i,j->ij', u, v)  # output[i,j] = u[i] * v[j]
# Transpose
>>> einsum('ij->ji', m)  # output[j,i] = m[i,j]
# Trace
>>> einsum('ii', m)  # output = trace(m) = sum_i m[i,i]
# Batch matrix multiplication
>>> einsum('aij,ajk->aik', s, t)  # output[a,i,k] = sum_j s[a,i,j] * t[a,j,k]
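Each example can be checked against a brute-force reading of the equation: for every assignment of the output indices, multiply the matching input elements and sum over the indices that do not appear in the output. The sketch below (naive_einsum is a hypothetical name) assumes explicit '->' equations without ellipses, so the trace must be written 'ii->', and nested Python lists as operands; it does not validate shapes and is far slower than tf.einsum.

```python
from itertools import product


def shape(x):
    """Dimensions of a (rectangular) nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims


def element(x, index):
    """Look up x[i0][i1]... for a sequence of indices."""
    for i in index:
        x = x[i]
    return x


def naive_einsum(equation, *operands):
    """Hypothetical brute-force einsum for explicit 'inputs->output' equations."""
    lhs, out = equation.split("->")
    in_subs = lhs.split(",")
    sizes = {}
    for subs, op in zip(in_subs, operands):
        for label, n in zip(subs, shape(op)):
            sizes[label] = n  # no shape-consistency checking in this sketch
    summed = [label for label in sizes if label not in out]

    def build(prefix):
        if len(prefix) == len(out):
            # All output indices fixed: sum the product over the remaining labels.
            fixed = dict(zip(out, prefix))
            total = 0
            for values in product(*(range(sizes[label]) for label in summed)):
                env = {**fixed, **dict(zip(summed, values))}
                term = 1
                for subs, op in zip(in_subs, operands):
                    term *= element(op, [env[label] for label in subs])
                total += term
            return total
        # Otherwise, add one more nesting level for the next output label.
        label = out[len(prefix)]
        return [build(prefix + [i]) for i in range(sizes[label])]

    return build([])
```

For instance, naive_einsum('ii->', [[1, 2], [3, 4]]) sums the diagonal and returns 5, and naive_einsum('i,i->', [1, 2, 3], [4, 5, 6]) returns the dot product 32.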
To enable and control broadcasting, use an ellipsis. For example, to do batch matrix multiplication, you could use:
einsum('...ij,...jk->...ik', u, v)
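At the level of subscripts, each ellipsis can be read as a run of fresh labels for the leading dimensions, aligned from the right as in ordinary broadcasting. The hypothetical helper expand_ellipsis below rewrites an equation that way for given input ranks; it is an illustration only and does not model size-1 broadcasting of the dimension values themselves.

```python
import string


def expand_ellipsis(equation, ranks):
    """Hypothetical helper: replace each '...' with explicit labels, given
    the rank of every input. Explicit 'inputs->output' equations only."""
    lhs, out = equation.split("->")
    in_subs = lhs.split(",")
    # Fresh labels for the broadcast dimensions, avoiding any used letter.
    fresh = [c for c in string.ascii_uppercase if c not in equation]
    # How many leading dimensions the '...' covers in each input.
    n_ell = [rank - len(subs.replace("...", "")) for subs, rank in zip(in_subs, ranks)]
    width = max(n_ell)
    batch = "".join(fresh[:width])
    # Right-align: an input with fewer broadcast dims takes the trailing labels.
    expanded = [subs.replace("...", batch[width - n:]) for subs, n in zip(in_subs, n_ell)]
    return ",".join(expanded) + "->" + out.replace("...", batch)


print(expand_ellipsis("...ij,...jk->...ik", [3, 3]))  # Aij,Ajk->Aik
print(expand_ellipsis("...ij,...jk->...ik", [4, 2]))  # ABij,jk->ABik
```

The second call shows a rank-2 right operand being shared across both batch dimensions of a rank-4 left operand.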
This function behaves like numpy.einsum, but does not support:
- Subscripts where an axis appears more than once for a single input (e.g. ijj,k->ik), unless it is a trace (as in the 'ii' example above).
Args:
- equation: a str describing the contraction, in the same format as numpy.einsum.
- *inputs: the inputs to contract (each one a Tensor), whose shapes should be consistent with equation.
- name: A name for the operation (optional).

Returns:
A Tensor, with shape determined by equation.

Raises:
ValueError: If
- the format of equation is incorrect,
- the number of inputs implied by equation does not match the number of tensors passed in inputs,
- an axis appears in the output subscripts but not in any of the inputs,
- the number of dimensions of an input differs from the number of indices in its subscript, or
- the input shapes are inconsistent along a particular axis.
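The error conditions listed above can be mirrored by a small pre-flight check on static shapes. check_einsum below is a hypothetical helper, not a TensorFlow API: it accepts only explicit equations with letter subscripts (no ellipsis), works on plain shape tuples rather than tensors, returns the implied output shape, and its messages will differ from TensorFlow's.

```python
def check_einsum(equation, *shapes):
    """Hypothetical pre-flight check mirroring the documented ValueError
    conditions; returns the output shape implied by the equation."""
    # Format: exactly one '->', letters only (this sketch has no ellipsis support).
    if equation.count("->") != 1:
        raise ValueError("expected an explicit 'inputs->output' equation")
    lhs, out = equation.split("->")
    if not (lhs.replace(",", "") + out).isalpha():
        raise ValueError(f"malformed equation: {equation!r}")
    in_subs = lhs.split(",")
    # Number of inputs implied by the equation vs. number supplied.
    if len(in_subs) != len(shapes):
        raise ValueError(f"equation implies {len(in_subs)} inputs, got {len(shapes)}")
    sizes = {}
    for n, (subs, shp) in enumerate(zip(in_subs, shapes)):
        # Rank must equal the number of indices in the subscript.
        if len(subs) != len(shp):
            raise ValueError(f"input {n}: {len(subs)} indices but rank {len(shp)}")
        # The same axis label must have the same size everywhere.
        for label, dim in zip(subs, shp):
            if sizes.setdefault(label, dim) != dim:
                raise ValueError(f"axis {label!r}: sizes {sizes[label]} and {dim} differ")
    # Every output axis must come from some input.
    for label in out:
        if label not in sizes:
            raise ValueError(f"output axis {label!r} not found in any input")
    return tuple(sizes[label] for label in out)


print(check_einsum("aij,ajk->aik", (5, 2, 3), (5, 3, 4)))  # (5, 2, 4)
```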