


Generates adversarial neighbors for the given input and loss.


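This function implements the following operation: adv_neighbor = input_features + adv_step_size * gradient, where gradient is the gradient of labeled_loss with respect to input_features (with feature_mask applied and normalized according to adv_grad_norm), and adv_step_size is the step size (analogous to a learning rate) for searching for the adversarial neighbor.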
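
As a minimal sketch of this update, the NumPy code below uses a toy linear model with squared loss, so the gradient of the loss with respect to the input can be written by hand instead of being recorded with a tf.GradientTape. It assumes L2 normalization over all non-batch axes and no feature mask; the model, the loss, and the helper names are illustrative assumptions, not part of the library.

```python
import numpy as np

def l2_normalize_per_example(grad, eps=1e-12):
    # Axis 0 is the batch dimension; each example is scaled to unit L2 norm.
    flat = grad.reshape(grad.shape[0], -1)
    norms = np.maximum(np.sqrt(np.sum(flat ** 2, axis=1)), eps)
    return grad / norms.reshape((-1,) + (1,) * (grad.ndim - 1))

def adv_neighbor_linear(x, y, w, adv_step_size):
    # Toy model (assumption): prediction = x @ w, loss = 0.5 * sum((x @ w - y) ** 2).
    residual = x @ w - y                    # shape [batch_size]
    grad = residual[:, None] * w[None, :]   # d loss / d x, shape [batch_size, feat_len]
    # adv_neighbor = input_features + adv_step_size * normalized gradient
    return x + adv_step_size * l2_normalize_per_example(grad)
```

Because the gradient is normalized per example, scaling the loss (for example, a sum versus a mean over the batch) does not change the perturbation; under L2 normalization each example moves exactly adv_step_size in the direction that increases its loss.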

Args:
  • input_features: A Tensor or a dictionary of (feature_name, Tensor). The shape of the tensor(s) should be either (a) pointwise samples: [batch_size, feat_len], or (b) sequence samples: [batch_size, seq_len, feat_len]. Note that only dense (float) tensors in input_features will be perturbed; all other features (int, string, or SparseTensor) are kept as-is in the returned adv_neighbor.
  • labeled_loss: A scalar tensor of floating-point type calculated from the true labels (or other supervision).
  • config: An nsl.configs.AdvNeighborConfig object containing the following hyperparameters for generating adversarial samples.
    • 'feature_mask': mask (with 0-1 values) applied to the gradient.
    • 'adv_step_size': step size to find the adversarial sample.
    • 'adv_grad_norm': type of tensor norm to normalize the gradient.
  • raise_invalid_gradient: (optional) A Boolean flag indicating whether to raise an error when the gradient cannot be computed for any of the input features. This can happen in three cases: (1) the feature is a SparseTensor; (2) the feature has a non-differentiable dtype, like string or integer; (3) the feature is not involved in the loss computation. If set to False (default), inputs without a gradient are silently ignored and not perturbed.
  • gradient_tape: A tf.GradientTape object watching the calculation from input_features to labeled_loss. Can be omitted if running in graph mode.
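
The config hyperparameters and the dtype rule for input_features can be illustrated together with a NumPy sketch. It assumes the mask is applied before normalization, takes the norm per example over all non-batch axes (so both pointwise and sequence shapes work), and uses the steepest-ascent direction inside the unit norm ball: g / ||g||_2 for L2 and sign(g) for the infinity norm. For simplicity each feature is normalized on its own. The helper names, the "l2"/"infinity" strings, and the dict-based mask are hypothetical; this is not the library's implementation.

```python
import numpy as np

def unit_norm_direction(grad, norm, eps=1e-12):
    """Per-example direction d maximizing <grad, d> subject to ||d|| <= 1.

    Axis 0 is the batch dimension; the norm covers all remaining axes, so
    [batch_size, feat_len] and [batch_size, seq_len, feat_len] both work.
    """
    if norm == "infinity":
        # Steepest ascent in the unit infinity-norm ball (FGSM-style sign step).
        return np.sign(grad)
    if norm == "l2":
        flat = grad.reshape(grad.shape[0], -1)
        norms = np.maximum(np.linalg.norm(flat, axis=1), eps)
        return grad / norms.reshape((-1,) + (1,) * (grad.ndim - 1))
    raise ValueError(f"unsupported norm: {norm}")

def perturb_features(features, gradients, feature_mask, adv_step_size, norm):
    """Hypothetical helper: mask, normalize, and step only floating-point features."""
    adv_neighbor = {}
    for name, value in features.items():
        if np.issubdtype(value.dtype, np.floating):
            # Mask first (0 freezes a coordinate), then normalize, then step.
            masked = gradients[name] * feature_mask.get(name, 1.0)
            adv_neighbor[name] = value + adv_step_size * unit_norm_direction(masked, norm)
        else:
            # Integer and string features are passed through unchanged.
            adv_neighbor[name] = value
    return adv_neighbor
```

With a mask of [1, 1, 0], the third coordinate never moves, and switching from "l2" to "infinity" changes the shape of the step from a unit vector to a sign pattern of equal magnitudes.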

Returns:
  • adv_neighbor: The perturbed example, with the same shape and structure as input_features.
  • adv_weight: A dense Tensor with shape [batch_size, 1], representing the weight for each neighbor.

Raises:
  • ValueError: If raise_invalid_gradient is set to True and some of the input features cannot be perturbed. See raise_invalid_gradient for situations where this can happen.
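
The raise_invalid_gradient contract can be sketched with a hypothetical helper that receives one gradient per feature, using None for a feature whose gradient could not be computed (a SparseTensor, a non-differentiable dtype, or a feature not involved in the loss). It illustrates the documented behavior only and is not the library's code.

```python
def select_perturbable(gradients, raise_invalid_gradient=False):
    """Keep features that have a gradient; optionally fail on the rest.

    gradients: dict mapping feature name to a gradient, or None if unavailable.
    """
    missing = sorted(name for name, grad in gradients.items() if grad is None)
    if missing and raise_invalid_gradient:
        raise ValueError(f"Cannot perturb features without a gradient: {missing}")
    # Default behavior: features without a gradient are dropped here and
    # therefore left unperturbed.
    return {name: grad for name, grad in gradients.items() if grad is not None}
```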