tf.contrib.gan.estimator.GANHead

Class GANHead

Aliases:

  • Class tf.contrib.gan.estimator.GANHead
  • Class tf.contrib.gan.estimator.head.GANHead

Defined in tensorflow/contrib/gan/python/estimator/python/head_impl.py.

Head for a GAN.

Properties

logits_dimension

name

Methods

__init__

__init__(
    generator_loss_fn,
    discriminator_loss_fn,
    generator_optimizer,
    discriminator_optimizer,
    use_loss_summaries=True,
    get_hooks_fn=None,
    name=None
)

Head for GAN training.

Args:

  • generator_loss_fn: A TFGAN loss function for the generator. Takes a GANModel and returns a scalar.
  • discriminator_loss_fn: Same as generator_loss_fn, but for the discriminator.
  • generator_optimizer: The optimizer for generator updates.
  • discriminator_optimizer: Same as generator_optimizer, but for the discriminator updates.
  • use_loss_summaries: If True, adds loss summaries. If False, does not. If None, uses the default behavior.
  • get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list of hooks. Defaults to train.get_sequential_train_hooks().
  • name: Name of the head. If provided, summary and metrics keys are suffixed with "/" + name.
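The naming rule for name can be illustrated with a small plain-Python sketch. The helper suffixed_key and the example keys "loss" and "gan_head" are hypothetical. The sketch only mirrors the documented "/" + name suffix and does not reproduce how the library builds its keys internally.

```python
# Hypothetical helper illustrating the documented naming rule;
# it is not part of tf.contrib.gan.
def suffixed_key(key, head_name=None):
    """Return `key`, suffixed with "/" + head_name when a head name is provided."""
    if head_name is None:
        return key
    return key + "/" + head_name


print(suffixed_key("loss"))              # loss
print(suffixed_key("loss", "gan_head"))  # loss/gan_head
```

Without a name, keys are left unchanged, so two heads only need distinct names when their summaries or metrics would otherwise collide.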

create_estimator_spec

create_estimator_spec(
    features,
    mode,
    logits,
    labels=None,
    train_op_fn=tf.contrib.gan.gan_train_ops
)

Returns EstimatorSpec that a model_fn can return.

See Head for more details.

Args:

  • features: Must be None.
  • mode: A tf.estimator.ModeKeys value.
  • logits: A GANModel tuple.
  • labels: Must be None.
  • train_op_fn: Function that takes a GANModel, a GANLoss, the generator optimizer, and the discriminator optimizer, and returns a GANTrainOps tuple. It can come from TFGAN's train.py library (as the default, gan_train_ops, does) or be a custom function.

Returns:

EstimatorSpec.

Raises:

  • ValueError: If features isn't None.
  • ValueError: If train_op_fn isn't provided in train mode.
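The two documented ValueError conditions amount to a simple argument check. The following is a minimal plain-Python sketch, assuming a hypothetical helper check_estimator_spec_args and a ModeKeys stand-in that mirrors the string values of tf.estimator.ModeKeys. It is not the library's implementation and does not build an EstimatorSpec. Because the real method defaults train_op_fn to tf.contrib.gan.gan_train_ops, the second error would only arise if None is passed explicitly.

```python
class ModeKeys:
    """Stand-in mirroring the string values of tf.estimator.ModeKeys."""
    TRAIN = "train"
    EVAL = "eval"
    PREDICT = "infer"


def check_estimator_spec_args(features, mode, train_op_fn=None):
    """Hypothetical check for the two documented ValueError conditions."""
    if features is not None:
        raise ValueError("`features` should be None, got %r." % (features,))
    if mode == ModeKeys.TRAIN and train_op_fn is None:
        raise ValueError("`train_op_fn` must be provided in train mode.")


# Passes: no features, and eval mode does not need train_op_fn.
check_estimator_spec_args(None, ModeKeys.EVAL)

# Each of these argument sets raises ValueError.
for bad_args in [({"x": 1}, ModeKeys.EVAL, None), (None, ModeKeys.TRAIN, None)]:
    try:
        check_estimator_spec_args(*bad_args)
    except ValueError as err:
        print("rejected:", err)
```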

create_loss

create_loss(
    features,
    mode,
    logits,
    labels
)

Returns a GANLoss tuple from the provided GANModel.

See Head for more details.

Args:

  • features: Input dict of Tensor objects. Unused.
  • mode: A tf.estimator.ModeKeys value.
  • logits: A GANModel tuple.
  • labels: Must be None.

Returns:

A GANLoss tuple.
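Conceptually, create_loss applies the head's generator and discriminator loss functions to the same GANModel and bundles the two results into a GANLoss. The sketch below assumes plain-Python stand-ins: namedtuples that model only two GANModel fields and the two GANLoss fields, floats in place of Tensors, and toy Wasserstein-style loss functions (toy_generator_loss, toy_discriminator_loss, both hypothetical) in place of TFGAN's losses. Unlike the real method, it takes the loss functions as arguments rather than reading them from the head.

```python
import collections

# Stand-ins for TFGAN's namedtuples. GANModel models only the two
# discriminator-output fields used below; values are floats, not Tensors.
GANModel = collections.namedtuple(
    "GANModel", ["discriminator_real_outputs", "discriminator_gen_outputs"])
GANLoss = collections.namedtuple(
    "GANLoss", ["generator_loss", "discriminator_loss"])


def mean(values):
    return sum(values) / len(values)


def toy_generator_loss(model):
    # Wasserstein-style generator loss: -mean(D(G(z))).
    return -mean(model.discriminator_gen_outputs)


def toy_discriminator_loss(model):
    # Wasserstein-style discriminator loss: mean(D(G(z))) - mean(D(x)).
    return (mean(model.discriminator_gen_outputs)
            - mean(model.discriminator_real_outputs))


def create_loss(gan_model, generator_loss_fn, discriminator_loss_fn):
    """Apply both loss functions to one GANModel and bundle them in a GANLoss."""
    return GANLoss(
        generator_loss=generator_loss_fn(gan_model),
        discriminator_loss=discriminator_loss_fn(gan_model),
    )


model = GANModel(discriminator_real_outputs=[2.0, 4.0],
                 discriminator_gen_outputs=[0.5, 1.5])
print(create_loss(model, toy_generator_loss, toy_discriminator_loss))
# GANLoss(generator_loss=-1.0, discriminator_loss=-2.0)
```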