
tfnlp.layers.TNTransformerExpandCondense

Transformer layer using tensor network Expand-Condense layer.

This layer implements the Transformer from transformer.py, with a single tensor network layer replacing the usual intermediate and output Dense layers.
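For orientation, the two Dense layers that the Expand-Condense layer replaces form the standard Transformer feed-forward sub-block: expand from the hidden size to intermediate_size, apply intermediate_activation, then condense back to the hidden size. The NumPy sketch below writes out that dense baseline only. It does not reproduce the tensor network layer, which keeps the same hidden-size-in, hidden-size-out contract but stores its weights as a contraction of smaller tensors. GELU, the sizes, and the helper names are illustrative assumptions.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU; in the real layer the activation is
    # whatever intermediate_activation is set to.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def dense_feed_forward(x, w_in, b_in, w_out, b_out, activation=gelu):
    # Expand: hidden_size -> intermediate_size, then apply the activation.
    h = activation(x @ w_in + b_in)
    # Condense: intermediate_size -> hidden_size.
    return h @ w_out + b_out


rng = np.random.default_rng(0)
batch, seq_len, hidden_size, intermediate_size = 2, 4, 16, 64
x = rng.standard_normal((batch, seq_len, hidden_size))
w_in = rng.standard_normal((hidden_size, intermediate_size)) * 0.02
b_in = np.zeros(intermediate_size)
w_out = rng.standard_normal((intermediate_size, hidden_size)) * 0.02
b_out = np.zeros(hidden_size)

y = dense_feed_forward(x, w_in, b_in, w_out, b_out)
# Parameter count of the two dense layers (kernels plus biases).
num_params = w_in.size + b_in.size + w_out.size + b_out.size
print(y.shape, num_params)  # (2, 4, 16) 2128
```

The output shape matches the input shape, which is what allows a single tensor network layer to be dropped in where these two Dense layers were; the dense parameter count here is the baseline that a factorized replacement would presumably aim to reduce.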

Args
num_attention_heads Number of attention heads.
intermediate_size Size of the intermediate layer.
intermediate_activation Activation for the intermediate layer.
dropout_rate Dropout probability for the post-attention and output dropout.
attention_dropout_rate Dropout probability within the attention layer.
output_range The sequence output range, [0, output_range), obtained by slicing the target sequence. None means the target sequence is not sliced.
kernel_initializer Initializer for dense layer kernels.
bias_initializer Initializer for dense layer biases.
kernel_regularizer Regularizer for dense layer kernels.
bias_regularizer Regularizer for dense layer biases.
activity_regularizer Regularizer for dense layer activity.
kernel_constraint Constraint for dense layer kernels.
bias_constraint Constraint for dense layer biases.
use_bias Whether the attention layer uses bias terms. If False, biases are disabled in the attention layer.
norm_first Whether to normalize the inputs to the attention and intermediate dense layers. If False, the outputs of the attention and intermediate dense layers are normalized instead.
norm_epsilon Epsilon value used to initialize the normalization layers.
intermediate_dropout Dropout probability for intermediate_dropout_layer.
attention_initializer Initializer for the kernels of the attention layers. If None, the attention layers use kernel_initializer.
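To make output_range and norm_first concrete, here is a NumPy sketch of an attention sub-block with a residual connection. It is a simplified illustration under stated assumptions, not the layer's implementation: toy_attention is a hypothetical single-head attention with no learned projections, dropout is omitted, and layer_norm has no learned scale or offset. Following the descriptions above, output_range slices the target (query) positions while the full sequence still serves as keys and values, and norm_first switches between normalizing the sublayer input (pre-norm) and normalizing the residual sum (post-norm).

```python
import numpy as np


def layer_norm(x, epsilon=1e-12):
    # Normalize over the hidden (last) axis; learned scale/offset omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)


def toy_attention(query, key_value):
    # Single-head dot-product attention without learned projections.
    scores = query @ key_value.transpose(0, 2, 1) / np.sqrt(query.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ key_value


def attention_sublayer(x, output_range=None, norm_first=False, epsilon=1e-12):
    """x: [batch, seq_len, hidden] -> [batch, output_range or seq_len, hidden]."""
    # None keeps every target position; otherwise keep positions [0, output_range).
    end = x.shape[1] if output_range is None else output_range
    if norm_first:
        # Pre-norm: normalize the attention inputs, then add the raw residual.
        normed = layer_norm(x, epsilon)
        return x[:, :end, :] + toy_attention(normed[:, :end, :], normed)
    # Post-norm: attend on raw inputs, then normalize residual + attention output.
    target = x[:, :end, :]
    return layer_norm(target + toy_attention(target, x), epsilon)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 8))
post = attention_sublayer(x, output_range=3, norm_first=False)
pre = attention_sublayer(x, output_range=None, norm_first=True)
print(post.shape, pre.shape)  # (2, 3, 8) (2, 5, 8)
```

In this sketch, slicing the target sequence gives the same values as computing all positions and then keeping the first output_range of them; the saving is that attention is only evaluated for the kept query positions.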

Methods

call


This is where the layer's logic lives.

Note that the call() method in tf.keras differs slightly from the Keras API. In the Keras API, masking support can be passed to layers as additional arguments, whereas tf.keras uses the compute_mask() method to support masking.

Args
inputs Input tensor, or list/tuple of input tensors.
*args Additional positional arguments. Currently unused.
**kwargs Additional keyword arguments. Currently unused.

Returns
A tensor or list/tuple of tensors.
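The Args above say only that inputs may be a single tensor or a list/tuple of tensors. As a hedged illustration, the NumPy sketch below assumes a hypothetical [input_tensor, attention_mask] convention and shows how such inputs might be unpacked and how a mask is typically applied to attention scores. unpack_inputs and masked_softmax are illustrative helpers, not part of the documented API, and the mask convention (1 = attend, 0 = block) is an assumption.

```python
import numpy as np


def unpack_inputs(inputs):
    # Hypothetical helper: accept a bare tensor, or a list/tuple whose first
    # element is the input tensor and whose optional second element is a mask.
    if isinstance(inputs, (list, tuple)):
        if len(inputs) == 1:
            return inputs[0], None
        if len(inputs) == 2:
            return inputs[0], inputs[1]
        raise ValueError(f"expected 1 or 2 inputs, got {len(inputs)}")
    return inputs, None


def masked_softmax(scores, mask=None):
    # Assumed convention: mask is 1 where attention is allowed, 0 where blocked.
    if mask is not None:
        scores = np.where(mask.astype(bool), scores, -1e9)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


x = np.ones((1, 3, 4))
# Block attention to the last (padding) position for every query.
mask = np.array([[[1, 1, 0]] * 3])
tensor, attention_mask = unpack_inputs([x, mask])
scores = tensor @ tensor.transpose(0, 2, 1)  # all scores are equal here
weights = masked_softmax(scores, attention_mask)
print(weights[0, 0])  # [0.5 0.5 0. ]
```

Because every score is equal, the unmasked positions split the attention weight evenly and the masked position receives none.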