The gradient penalty for the Wasserstein discriminator loss.
tf.contrib.gan.losses.wargs.wasserstein_gradient_penalty(
    real_data,
    generated_data,
    generator_inputs,
    discriminator_fn,
    discriminator_scope,
    epsilon=1e-10,
    target=1.0,
    one_sided=False,
    weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False
)
See "Improved Training of Wasserstein GANs" (https://arxiv.org/abs/1704.00028) for more details.
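As a sketch of what the function computes (the formula below comes from the cited WGAN-GP paper, not from this page): the penalty measures how far the norm of the discriminator's gradient, evaluated at points interpolated between real and generated data, deviates from target:

\[
\mathcal{L}_{\text{GP}} = \mathbb{E}_{\hat{x}}\!\left[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - \text{target}\big)^2\right],
\qquad \hat{x} = \alpha\, x_{\text{real}} + (1 - \alpha)\, x_{\text{gen}},\quad \alpha \sim U[0, 1].
\]

With one_sided=True, the squared term becomes \(\max\big(0,\ \lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - \text{target}\big)^2\), so only gradient norms above the target are penalized.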
Args:
real_data: Real data.
generated_data: Output of the generator.
generator_inputs: Exact argument to pass to the generator, which is used as optional conditioning to the discriminator.
discriminator_fn: A discriminator function that conforms to TF-GAN API.
discriminator_scope: If not None, reuse discriminators from this scope.
epsilon: A small positive number added for numerical stability when computing the gradient norm.
target: Optional Python number or Tensor indicating the target value of the gradient norm. Defaults to 1.0.
one_sided: If True, the penalty proposed in https://arxiv.org/abs/1709.08894 is used. Defaults to False.
weights: Optional Tensor whose rank is either 0, or the same rank as real_data and generated_data, and must be broadcastable to them (i.e., all dimensions must be either 1, or the same as the corresponding dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A tf.compat.v1.losses.Reduction to apply to the loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on reduction.

Raises:
ValueError: If the rank of data Tensors is unknown.
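A minimal usage sketch (TF 1.x graph mode), assuming a hypothetical one-layer discriminator_fn and random tensors standing in for the real and generated batches; none of these helper names or shapes come from this page:

import tensorflow as tf

# Hypothetical discriminator conforming to the TF-GAN API:
# (data, generator_inputs) -> logits. For illustration only.
def discriminator_fn(data, unused_generator_inputs):
    net = tf.layers.flatten(data)
    return tf.layers.dense(net, 1, name='logits')

# Random tensors standing in for a real batch and a generator batch.
real_data = tf.random_normal([16, 28, 28, 1])
generated_data = tf.random_normal([16, 28, 28, 1])

# Build the discriminator variables once; the penalty reuses them
# via discriminator_scope when it evaluates interpolated points.
with tf.variable_scope('discriminator') as dis_scope:
    discriminator_fn(real_data, None)

gp_loss = tf.contrib.gan.losses.wargs.wasserstein_gradient_penalty(
    real_data,
    generated_data,
    generator_inputs=None,  # no conditioning in this sketch
    discriminator_fn=discriminator_fn,
    discriminator_scope=dis_scope,
    add_summaries=True)

# In WGAN-GP training the penalty is scaled (the paper uses a
# coefficient of 10) and added to the Wasserstein discriminator loss.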