
tf.contrib.gan.losses.wargs.acgan_generator_loss


ACGAN loss for the generator.

tf.contrib.gan.losses.wargs.acgan_generator_loss(
    discriminator_gen_classification_logits,
    one_hot_labels,
    weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False
)

The ACGAN loss adds a classification loss to the conditional discriminator. Therefore, the discriminator must output a tuple consisting of (1) the real/fake prediction and (2) the logits for the classification (usually the last conv layer, flattened).
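As a rough illustration of what the generator-side term measures, here is a plain-Python sketch with no TensorFlow dependency. It assumes the loss is the softmax cross-entropy between `one_hot_labels` and the classification logits the discriminator produced for generated data, averaged over the batch. That average is what the default `SUM_BY_NONZERO_WEIGHTS` reduction yields when `weights=1.0`. The name `acgan_generator_loss_sketch` is hypothetical and is not part of TF-GAN.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def acgan_generator_loss_sketch(
    gen_classification_logits: Sequence[Sequence[float]],
    one_hot_labels: Sequence[Sequence[float]],
) -> float:
    """Hypothetical sketch: mean softmax cross-entropy over the batch.

    Each row of `gen_classification_logits` holds the discriminator's class
    logits for one generated example; the matching row of `one_hot_labels`
    is the class the generator was conditioned on.
    """
    if len(gen_classification_logits) != len(one_hot_labels):
        raise ValueError("logits and labels must have the same batch size")
    per_example = []
    for logits, labels in zip(gen_classification_logits, one_hot_labels):
        log_probs = log_softmax(logits)
        # Cross-entropy for one example: -sum_k y_k * log p_k.
        per_example.append(-sum(y * lp for y, lp in zip(labels, log_probs)))
    # Assumes scalar weights=1.0, so SUM_BY_NONZERO_WEIGHTS is the batch mean.
    return sum(per_example) / len(per_example)


# Two generated examples, three classes.
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
labels = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
print(acgan_generator_loss_sketch(logits, labels))
```

The generator lowers this term by producing samples that the discriminator's classifier assigns to the intended class. Uniform logits give a loss of log(num_classes).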

For more details:

  • ACGAN: https://arxiv.org/abs/1610.09585

Args:

  • discriminator_gen_classification_logits: Classification logits for generated data.
  • one_hot_labels: A Tensor holding one-hot labels for the batch.
  • weights: Optional Tensor whose rank is either 0 or the same as the rank of discriminator_gen_classification_logits, and which must be broadcastable to discriminator_gen_classification_logits (i.e., each dimension must be either 1 or the same as the corresponding dimension).
  • scope: The scope for the operations performed in computing the loss.
  • loss_collection: Collection to which this loss will be added.
  • reduction: A tf.compat.v1.losses.Reduction to apply to the loss.
  • add_summaries: Whether or not to add summaries for the loss.
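The shape rule for `weights` amounts to a simple check on static shapes. The sketch below encodes only the rule stated above: the weights must be rank 0, or have the same rank as the logits with every dimension either 1 or equal to the logits' dimension. `weights_shape_ok` is a hypothetical helper, not a TensorFlow function, and it operates on plain shape tuples rather than tensors.

```python
from typing import Sequence


def weights_shape_ok(weights_shape: Sequence[int], logits_shape: Sequence[int]) -> bool:
    """Return True if `weights_shape` satisfies the documented `weights` rule."""
    if len(weights_shape) == 0:
        # A scalar (rank-0) weight applies to every element.
        return True
    if len(weights_shape) != len(logits_shape):
        return False
    # Each weights dimension must be 1 (broadcast) or match the logits dimension.
    return all(w_dim == 1 or w_dim == l_dim for w_dim, l_dim in zip(weights_shape, logits_shape))


logits_shape = (32, 10)  # e.g. a batch of 32 examples, 10 classes
print(weights_shape_ok((), logits_shape))       # True: scalar
print(weights_shape_ok((32, 1), logits_shape))  # True: one weight per example
print(weights_shape_ok((32,), logits_shape))    # False: rank mismatch
```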

Returns:

A loss Tensor. Shape depends on reduction.
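The dependence of the returned value's shape on `reduction` can be sketched in plain Python. This sketch assumes per-example losses and weights that already have the same shape (a flat list here). It follows the documented meaning of the `tf.compat.v1.losses.Reduction` options, and assumes that a zero denominator yields 0. It is an illustration, not TensorFlow's implementation, and `reduce_loss` is a hypothetical name.

```python
from typing import List, Sequence, Union


def _safe_div(numerator: float, denominator: float) -> float:
    """Assumed behavior: return 0 instead of dividing by zero."""
    return 0.0 if denominator == 0 else numerator / denominator


def reduce_loss(
    losses: Sequence[float], weights: Sequence[float], reduction: str
) -> Union[float, List[float]]:
    """Apply a Reduction-style option (named by string) to weighted losses."""
    weighted = [loss_i * w for loss_i, w in zip(losses, weights)]
    if reduction == "NONE":
        return weighted  # unreduced: same shape as the per-example losses
    total = sum(weighted)
    if reduction == "SUM":
        return total
    if reduction == "MEAN":
        # Divide by the sum of the weights.
        return _safe_div(total, sum(weights))
    if reduction == "SUM_OVER_BATCH_SIZE":
        # Divide by the number of loss elements.
        return _safe_div(total, len(weighted))
    if reduction == "SUM_BY_NONZERO_WEIGHTS":
        # Divide by the number of elements whose weight is nonzero.
        return _safe_div(total, sum(1 for w in weights if w != 0))
    raise ValueError(f"unknown reduction: {reduction}")


losses = [1.0, 2.0, 3.0]
weights = [2.0, 0.0, 2.0]
for r in ("NONE", "SUM", "MEAN", "SUM_OVER_BATCH_SIZE", "SUM_BY_NONZERO_WEIGHTS"):
    print(r, reduce_loss(losses, weights, r))
```

Only `NONE` keeps the per-example shape. Every other option returns a scalar, and the options differ only in the denominator.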

Raises:

  • ValueError: If the module argument is neither generator nor discriminator.
  • TypeError: If the discriminator does not output a tuple.
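The `TypeError` condition corresponds to the tuple contract described earlier: the discriminator is expected to return `(real/fake prediction, classification logits)`. A caller-side check might look like the plain-Python sketch below. `split_acgan_discriminator_output` is a hypothetical helper, and this sketch makes no claim about where TF-GAN itself performs the check.

```python
from typing import Any, Tuple


def split_acgan_discriminator_output(output: Any) -> Tuple[Any, Any]:
    """Hypothetical helper: unpack an ACGAN discriminator's output tuple."""
    if not isinstance(output, tuple):
        raise TypeError(
            "ACGAN discriminator must output a tuple of "
            "(real/fake prediction, classification logits); got %s"
            % type(output).__name__
        )
    # Tuple unpacking raises ValueError if there are not exactly two items.
    real_fake_prediction, classification_logits = output
    return real_fake_prediction, classification_logits


prediction, class_logits = split_acgan_discriminator_output((0.9, [1.2, -0.3]))
try:
    split_acgan_discriminator_output([0.9, [1.2, -0.3]])  # a list, not a tuple
except TypeError as err:
    print("rejected:", err)
```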