tf.contrib.gan.cyclegan_loss

Returns the losses for a CycleGANModel.

tf.contrib.gan.cyclegan_loss(
    model,
    generator_loss_fn=tf.contrib.gan.losses.least_squares_generator_loss,
    discriminator_loss_fn=tf.contrib.gan.losses.least_squares_discriminator_loss,
    cycle_consistency_loss_fn=tf.contrib.gan.losses.cycle_consistency_loss,
    cycle_consistency_loss_weight=10.0,
    **kwargs
)
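The defaults plug in least-squares (LSGAN) losses. The framework-free sketch below is only a rough illustration of them, computed on plain lists of discriminator outputs. It assumes a real label of 1, a fake label of 0, a factor of 1/2 on each squared term, and mean reduction. These constants reflect our reading of tf.contrib.gan.losses and should be checked against the installed version. The function names are hypothetical.

```python
from statistics import mean


def ls_generator_loss(d_on_generated, real_label=1.0):
    # Hypothetical helper (assumed constants): push D(G(x)) toward the "real" label.
    return mean(0.5 * (d - real_label) ** 2 for d in d_on_generated)


def ls_discriminator_loss(d_on_real, d_on_generated, real_label=1.0, fake_label=0.0):
    # Hypothetical helper: score real samples as real and generated samples as fake.
    loss_real = mean(0.5 * (d - real_label) ** 2 for d in d_on_real)
    loss_fake = mean(0.5 * (d - fake_label) ** 2 for d in d_on_generated)
    return loss_real + loss_fake


if __name__ == "__main__":
    d_real = [0.9, 0.8]  # arbitrary discriminator scores on real data
    d_fake = [0.2, 0.4]  # arbitrary discriminator scores on generated data
    print(ls_generator_loss(d_fake))
    print(ls_discriminator_loss(d_real, d_fake))
```

Squared error against fixed targets avoids the saturating log terms of the original GAN objective. That is the usual motivation for LSGAN-style defaults.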

See Zhu et al., "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" (https://arxiv.org/abs/1703.10593), for more details.

Args:

  • model: A CycleGANModel namedtuple.
  • generator_loss_fn: The loss function on the generator. Takes a GANModel namedtuple.
  • discriminator_loss_fn: The loss function on the discriminator. Takes a GANModel namedtuple.
  • cycle_consistency_loss_fn: The cycle consistency loss function. Takes a CycleGANModel namedtuple.
  • cycle_consistency_loss_weight: A non-negative Python number or a scalar Tensor indicating how heavily to weight the cycle consistency loss.
  • **kwargs: Keyword args passed directly to gan_loss when constructing the loss for each of the two partial models in model (model.model_x2y and model.model_y2x).
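Conceptually, the function builds one adversarial loss per translation direction and adds the weighted cycle consistency term to each generator loss. The stdlib-only sketch below mimics that composition on scalars and flat lists. It assumes that the cycle term is the average of the two mean absolute (L1) reconstruction errors, as in the CycleGAN paper and per our reading of the TF-GAN source. The helper names are hypothetical and are not part of tf.contrib.gan.

```python
from statistics import mean


def l1_cycle_consistency(data_x, reconstructed_x, data_y, reconstructed_y):
    # Assumed form: mean |x - G(F(x))| and mean |y - F(G(y))|, averaged.
    loss_x = mean(abs(a - b) for a, b in zip(data_x, reconstructed_x))
    loss_y = mean(abs(a - b) for a, b in zip(data_y, reconstructed_y))
    return (loss_x + loss_y) / 2


def combine_generator_losses(gen_loss_x2y, gen_loss_y2x, cycle_loss, weight=10.0):
    # Hypothetical helper: mirrors the documented non-negativity requirement,
    # then adds the same weighted cycle term to BOTH directions' generator losses.
    if weight < 0:
        raise ValueError("cycle_consistency_loss_weight must be non-negative")
    aux = weight * cycle_loss
    return gen_loss_x2y + aux, gen_loss_y2x + aux


cycle = l1_cycle_consistency([1.0, 2.0], [1.5, 2.0], [0.0, 0.0], [0.0, 1.0])
print(combine_generator_losses(0.25, 0.30, cycle))
```

In this sketch the discriminator losses are left untouched. Setting the weight to 0 reduces each direction to a plain GAN loss.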

Returns:

A CycleGANLoss namedtuple holding one GANLoss per direction in its loss_x2y and loss_y2x fields.
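The return value nests two per-direction losses. The sketch below models that shape with local collections.namedtuple stand-ins. The field names loss_x2y, loss_y2x, generator_loss and discriminator_loss follow tf.contrib.gan's CycleGANLoss and GANLoss as we understand them. The classes themselves are local stand-ins, and the numbers are arbitrary. Summing each role across directions is just one hypothetical way to get single scalars for logging.

```python
from collections import namedtuple

# Local stand-ins that mirror the field layout of tf.contrib.gan's namedtuples.
GANLoss = namedtuple("GANLoss", ["generator_loss", "discriminator_loss"])
CycleGANLoss = namedtuple("CycleGANLoss", ["loss_x2y", "loss_y2x"])


def total_losses(cyclegan_loss):
    # Hypothetical helper: sum each role across the two translation directions.
    gen = cyclegan_loss.loss_x2y.generator_loss + cyclegan_loss.loss_y2x.generator_loss
    dis = cyclegan_loss.loss_x2y.discriminator_loss + cyclegan_loss.loss_y2x.discriminator_loss
    return gen, dis


loss = CycleGANLoss(
    loss_x2y=GANLoss(generator_loss=4.0, discriminator_loss=0.5),
    loss_y2x=GANLoss(generator_loss=3.0, discriminator_loss=0.25),
)
print(total_losses(loss))  # (7.0, 0.75)
```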

Raises:

  • ValueError: If model is not a CycleGANModel namedtuple.
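The error case amounts to a type check on model. The minimal sketch below uses local namedtuple stand-ins, with GANModel trimmed to two fields, and a hypothetical check_cyclegan_model helper. None of these are the real tf.contrib.gan classes. A plain GANModel is rejected because it is a different namedtuple type, even though both are tuples.

```python
from collections import namedtuple

# Stand-ins for the real namedtuple types; only the type identity matters here.
GANModel = namedtuple("GANModel", ["generated_data", "discriminator_gen_outputs"])
CycleGANModel = namedtuple(
    "CycleGANModel", ["model_x2y", "model_y2x", "reconstructed_x", "reconstructed_y"])


def check_cyclegan_model(model):
    # Hypothetical helper: reject anything that is not a CycleGANModel.
    if not isinstance(model, CycleGANModel):
        raise ValueError(
            "`model` must be a CycleGANModel, got %s" % type(model).__name__)
    return model


try:
    check_cyclegan_model(GANModel(generated_data=None, discriminator_gen_outputs=None))
except ValueError as err:
    print(err)
```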