Generalization of tf.compat.v1.scatter_update to an axis other than 0. (deprecated)
tf.batch_scatter_update(ref, indices, updates, use_locking=True, name=None)
Analogous to batch_gather. This assumes that ref, indices, and updates
have a series of leading dimensions that are the same for all of them, and that the
updates are performed on the last dimension of indices. In other words, the
dimensions should be the following:
num_prefix_dims = indices.ndims - 1
batch_dim = num_prefix_dims + 1
updates.shape = indices.shape + var.shape[batch_dim:]
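The shape rule for updates can be checked with a small pure-Python helper. The name expected_updates_shape is hypothetical (it is not a TensorFlow function), and shapes are plain tuples here rather than TensorShape objects; this is a sketch of the rule as stated, not library code.

```python
def expected_updates_shape(indices_shape, var_shape):
    """Hypothetical helper: the shape `updates` must have, per the rules above."""
    num_prefix_dims = len(indices_shape) - 1
    batch_dim = num_prefix_dims + 1
    # indices supplies the prefix dims plus the j axis; var supplies any
    # trailing dims that follow the scattered axis.
    return tuple(indices_shape) + tuple(var_shape[batch_dim:])


# var: 2 batch rows, 5 slots each, every slot a length-4 vector.
# indices: 3 slot positions per row.
print(expected_updates_shape((2, 3), (2, 5, 4)))  # (2, 3, 4)
```

When var has no trailing dimensions (for example var.shape = (2, 5)), updates simply has the same shape as indices.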
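Here var denotes ref, and the leading num_prefix_dims dimensions of var, indices, and updates must be equal. The operation performed can be expressed as: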
var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]
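A minimal pure-Python reference for the formula above, using nested lists in place of tensors. The function name batch_scatter_update_ref is hypothetical. Unlike the real operation it returns a modified copy instead of writing into a variable, and for duplicate indices it simply lets the last write win, an ordering that should not be assumed of TensorFlow.

```python
import copy


def batch_scatter_update_ref(var, indices, updates):
    """Hypothetical reference for
    var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j].

    var, indices and updates are nested lists; returns an updated copy of var.
    """
    result = copy.deepcopy(var)

    def recurse(v, idx, upd):
        if idx and isinstance(idx[0], list):
            # Still walking the shared leading (prefix) dimensions.
            for v_sub, idx_sub, upd_sub in zip(v, idx, upd):
                recurse(v_sub, idx_sub, upd_sub)
        else:
            # Innermost level: idx is the j axis, a flat list of positions.
            # Each position selects an element (or a trailing-dims slice) of v.
            for j, pos in enumerate(idx):
                v[pos] = copy.deepcopy(upd[j])

    recurse(result, indices, updates)
    return result


var = [[0, 0, 0, 0], [0, 0, 0, 0]]
indices = [[1, 3], [0, 2]]
updates = [[10, 30], [40, 60]]
print(batch_scatter_update_ref(var, indices, updates))
# [[0, 10, 0, 30], [40, 0, 60, 0]]
```

Each batch row receives its own set of positions: row 0 is written at columns 1 and 3, row 1 at columns 0 and 2.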
When indices is a 1D tensor, this operation is equivalent to tf.compat.v1.scatter_update.
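With a 1D indices there are no prefix dimensions (num_prefix_dims = 0), so the formula reduces to var[indices[j], ...] = updates[j, ...], a plain scatter along axis 0. The sketch below is a hypothetical pure-Python stand-in for that case, not the TensorFlow implementation.

```python
import copy


def scatter_update_1d(var, indices, updates):
    """Hypothetical stand-in for the 1-D-indices case:
    var[indices[j], ...] = updates[j, ...] (a scatter along axis 0)."""
    result = copy.deepcopy(var)
    for j, pos in enumerate(indices):
        # With no prefix dimensions, each index selects a whole row of var.
        result[pos] = copy.deepcopy(updates[j])
    return result


var = [[0, 0], [0, 0], [0, 0]]
print(scatter_update_1d(var, [2, 0], [[5, 5], [6, 6]]))
# [[6, 6], [0, 0], [5, 5]]
```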
Without this operation, there would be two alternatives:
1) Reshaping the variable by merging its first ndims dimensions. However,
this is not possible because tf.reshape returns a Tensor, on which
tf.compat.v1.scatter_update cannot be used.
2) Looping over the first ndims dimensions of the variable and calling
tf.compat.v1.scatter_update on the subtensors that result from slicing the
first dimension. This is a valid option for ndims = 1, but less efficient than
this implementation.
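The two alternatives can be compared on plain Python lists for the ndims = 1 case, assuming a 2-D var of scalars. Both function names are hypothetical. via_merged_dims only illustrates the index arithmetic that merging the first two dimensions would need (offset each index by b * width); in TensorFlow the merged view would be a new Tensor rather than the variable, which is why alternative 1 is ruled out there. via_loop performs one separate scatter per slice, the per-slice work the text calls less efficient.

```python
import copy


def via_loop(var, indices, updates):
    """Alternative 2: slice along the first dimension and do a plain
    axis-0 scatter (row[pos] = value) on each slice."""
    result = copy.deepcopy(var)
    for b, row in enumerate(result):
        for j, pos in enumerate(indices[b]):
            row[pos] = updates[b][j]
    return result


def via_merged_dims(var, indices, updates):
    """Alternative 1, index arithmetic only: merge the first two dimensions
    into one flat axis and offset each index by b * width."""
    width = len(var[0])
    flat = [x for row in var for x in row]  # merged ("reshaped") copy of var
    for b, idx in enumerate(indices):
        for j, pos in enumerate(idx):
            flat[b * width + pos] = updates[b][j]
    # Split the flat axis back into the original 2-D layout.
    return [flat[b * width:(b + 1) * width] for b in range(len(var))]


var = [[0, 0, 0], [0, 0, 0]]
indices = [[2], [0]]
updates = [[5], [7]]
print(via_loop(var, indices, updates))         # [[0, 0, 5], [7, 0, 0]]
print(via_merged_dims(var, indices, updates))  # [[0, 0, 5], [7, 0, 0]]
```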
ref: Variable to scatter onto.
indices: Tensor containing indices as described above.
updates: Tensor of updates to apply to ref.
use_locking: Boolean indicating whether to lock the writing operation.
name: Optional scope name string.
Returns: Ref to the variable after it has been modified.
Raises: ValueError if the initial (leading) dimensions of ref, indices, and
updates are not the same.
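A sketch of the kind of shape check this error describes, using plain tuples for shapes. check_prefix_dims is a hypothetical helper, not TensorFlow's own validation code; it compares the leading indices.ndims - 1 dimensions of ref, indices, and updates.

```python
def check_prefix_dims(ref_shape, indices_shape, updates_shape):
    """Hypothetical check: the leading indices.ndims - 1 dimensions of
    ref, indices and updates must all be equal."""
    num_prefix_dims = len(indices_shape) - 1
    prefixes = {
        "ref": tuple(ref_shape[:num_prefix_dims]),
        "indices": tuple(indices_shape[:num_prefix_dims]),
        "updates": tuple(updates_shape[:num_prefix_dims]),
    }
    # All three prefixes collapse to one value only when they agree.
    if len(set(prefixes.values())) != 1:
        raise ValueError(f"leading dimensions differ: {prefixes}")


check_prefix_dims((2, 5, 4), (2, 3), (2, 3, 4))  # consistent, no error
try:
    check_prefix_dims((3, 5), (2, 3), (2, 3))  # ref has 3 rows, indices has 2
except ValueError as err:
    print(err)
```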