tf.contrib.gan.cyclegan_loss

tf.contrib.gan.cyclegan_loss(
    model,
    generator_loss_fn=tf.contrib.gan.losses.least_squares_generator_loss,
    discriminator_loss_fn=tf.contrib.gan.losses.least_squares_discriminator_loss,
    cycle_consistency_loss_fn=tf.contrib.gan.losses.cycle_consistency_loss,
    cycle_consistency_loss_weight=10.0,
    **kwargs
)

Defined in tensorflow/contrib/gan/python/train.py.

Returns the losses for a CycleGANModel.

See https://arxiv.org/abs/1703.10593 for more details.

Args:

  • model: A CycleGANModel namedtuple.
  • generator_loss_fn: The loss function on the generator. Takes a GANModel namedtuple.
  • discriminator_loss_fn: The loss function on the discriminator. Takes a GANModel namedtuple.
  • cycle_consistency_loss_fn: The cycle consistency loss function. Takes a CycleGANModel namedtuple.
  • cycle_consistency_loss_weight: A non-negative Python number or a scalar Tensor indicating how much weight to give the cycle consistency loss.
  • **kwargs: Keyword args to pass directly to gan_loss to construct the loss for each partial model of model.
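To illustrate how the arguments above interact, here is a minimal pure-Python sketch of the arithmetic. It is not the library's code. It assumes the least-squares adversarial losses and an L1 cycle consistency loss as described in the CycleGAN paper, and it assumes the weighted cycle term is added to each generator's adversarial loss. All function names below are hypothetical stand-ins that operate on plain lists of floats rather than Tensors.

```python
from statistics import mean


def ls_generator_loss(disc_on_generated, real_label=1.0):
    """Least-squares generator loss: push D(G(x)) toward the real label."""
    return mean((d - real_label) ** 2 for d in disc_on_generated) / 2.0


def ls_discriminator_loss(disc_on_real, disc_on_generated,
                          real_label=1.0, fake_label=0.0):
    """Least-squares discriminator loss on real and generated samples."""
    on_real = mean((d - real_label) ** 2 for d in disc_on_real) / 2.0
    on_fake = mean((d - fake_label) ** 2 for d in disc_on_generated) / 2.0
    return on_real + on_fake


def l1_cycle_loss(x, x_reconstructed, y, y_reconstructed):
    """Mean absolute reconstruction error, averaged over both directions."""
    err_x = mean(abs(a - b) for a, b in zip(x, x_reconstructed))
    err_y = mean(abs(a - b) for a, b in zip(y, y_reconstructed))
    return (err_x + err_y) / 2.0


def generator_objective(adversarial_loss, cycle_loss, cycle_weight=10.0):
    """Adversarial loss plus the weighted cycle consistency term (assumed form)."""
    if cycle_weight < 0:
        raise ValueError("cycle_weight must be non-negative")
    return adversarial_loss + cycle_weight * cycle_loss


# Toy values: the discriminator scores generated samples at 0.5,
# and the X reconstruction is off by 0.5 per element.
adv = ls_generator_loss([0.5, 0.5])
cyc = l1_cycle_loss([1.0, 2.0], [1.5, 2.5], [0.0], [0.0])
total = generator_objective(adv, cyc)
print(adv, cyc, total)  # 0.125 0.25 2.625
```

With the default weight of 10.0, the reconstruction term dominates the generator objective in this toy case, which is why the weight is worth tuning when reconstructions are already close.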

Returns:

A CycleGANLoss namedtuple.
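The shape of the returned value can be sketched with stand-in namedtuples. The field names below (loss_x2y, loss_y2x, generator_loss, discriminator_loss) reflect my understanding of the library's CycleGANLoss and GANLoss namedtuples and should be treated as an assumption. The numeric values are made up for illustration; in practice each field would hold a scalar Tensor.

```python
import collections

# Hypothetical stand-ins mirroring the assumed field names of the
# library's GANLoss and CycleGANLoss namedtuples.
GANLoss = collections.namedtuple(
    "GANLoss", ["generator_loss", "discriminator_loss"])
CycleGANLoss = collections.namedtuple(
    "CycleGANLoss", ["loss_x2y", "loss_y2x"])

# Illustrative values only.
loss = CycleGANLoss(
    loss_x2y=GANLoss(generator_loss=2.5, discriminator_loss=0.5),
    loss_y2x=GANLoss(generator_loss=3.0, discriminator_loss=0.25),
)

# Each translation direction carries its own generator and discriminator
# loss, so a training loop can build separate optimizers per direction.
summary = {
    direction: getattr(loss, direction)._asdict()
    for direction in loss._fields
}
print(summary["loss_x2y"]["generator_loss"])
print(summary["loss_y2x"]["discriminator_loss"])
```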

Raises:

  • ValueError: If model is not a CycleGANModel namedtuple.