Masked matmul router using experts-choose-tokens assignment.

This router uses the same mechanism as Mixture-of-Experts with Expert Choice: each expert selects its top expert_capacity tokens. An individual token may therefore be processed by multiple experts or by none at all.
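A minimal sketch of the expert-choice selection for a single group of tokens, in plain Python with lists standing in for tensors. It assumes router_probs[t][e] is the softmax router probability of token t for expert e. The helpers expert_choice and experts_per_token are hypothetical names, not part of the layer's API, and breaking ties toward the lower token index is an arbitrary choice made here.

```python
from typing import List


def expert_choice(router_probs: List[List[float]], expert_capacity: int) -> List[List[int]]:
    """Return, for each expert, the indices of the tokens it selects.

    router_probs[t][e] is the router probability of token t for expert e
    within one group. Each expert keeps its top `expert_capacity` tokens.
    """
    num_tokens = len(router_probs)
    num_experts = len(router_probs[0])
    selections = []
    for e in range(num_experts):
        # Rank tokens by expert e's column; ties go to the lower token index.
        ranked = sorted(range(num_tokens), key=lambda t: (-router_probs[t][e], t))
        selections.append(ranked[:expert_capacity])
    return selections


def experts_per_token(selections: List[List[int]], num_tokens: int) -> List[int]:
    """Count how many experts selected each token (0 means no expert processes it)."""
    counts = [0] * num_tokens
    for chosen in selections:
        for t in chosen:
            counts[t] += 1
    return counts


# 4 tokens, 3 experts; each row sums to 1 like a softmax over experts.
router_probs = [
    [0.5, 0.4, 0.1],  # token 0
    [0.3, 0.3, 0.4],  # token 1
    [0.1, 0.2, 0.7],  # token 2
    [0.4, 0.3, 0.3],  # token 3
]
selections = expert_choice(router_probs, expert_capacity=1)
print(selections)                        # [[0], [0], [2]]
print(experts_per_token(selections, 4))  # [2, 0, 1, 0]
```

Token 0 is processed by two experts, while tokens 1 and 3 are processed by none. Every expert always receives exactly expert_capacity tokens, so per-expert load is fixed by construction; what is not guaranteed is that every token gets routed.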

This layer uses the Keras add_loss() and add_metric() APIs.

Args:
  num_experts: Number of experts.
  jitter_noise: Amplitude of jitter noise applied to router logits.
  use_bias: Whether to use a bias term when computing the router weights.
  kernel_initializer: Kernel initializer for the router weights.
  bias_initializer: Bias initializer for the router weights.
  router_z_loss_weight: Weight of the router z-loss. Use a non-zero value if training becomes unstable (especially with dtype 'bfloat16' or lower precision).
  export_metrics: Whether to export metrics through the Keras add_metric() API.
  name: Layer name.
  **kwargs: Forwarded to the parent class constructor.
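The router z-loss scaled by router_z_loss_weight is not defined on this page. The sketch below assumes the definition from ST-MoE ("Designing Effective Sparse Expert Models", Zoph et al., 2022): the mean over all tokens of the squared log-sum-exp of each token's router logits. Plain Python lists stand in for tensors, and logsumexp and router_z_loss are hypothetical helper names, not the layer's API.

```python
import math
from typing import List


def logsumexp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(v))): shift by the max before exponentiating."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def router_z_loss(router_logits: List[List[List[float]]]) -> float:
    """Mean over all tokens of logsumexp(token logits) ** 2.

    router_logits has shape [num_groups, tokens_per_group, num_experts].
    """
    squared = [logsumexp(token) ** 2 for group in router_logits for token in group]
    return sum(squared) / len(squared)


# Both inputs give the same (uniform) routing probabilities,
# but the larger logits are penalized much more heavily.
small = router_z_loss([[[0.0, 0.0], [0.0, 0.0]]])  # (ln 2)^2, about 0.48
large = router_z_loss([[[5.0, 5.0], [5.0, 5.0]]])  # (5 + ln 2)^2, about 32.4
print(small, large)
```

Because the penalty grows with the magnitude of the logits rather than with the routing distribution, it pushes logits toward small values, which is why a non-zero weight can help in low-precision dtypes such as bfloat16. Combining it with the task loss as task_loss + router_z_loss_weight * z_loss is an assumption of this sketch.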


Computes dispatch and combine arrays for routing to experts.

Args:
  inputs: Inputs to send to experts, of shape [num_groups, tokens_per_group, hidden_dim].
  expert_capacity: Each group sends this many tokens to each expert.
  training: If true, apply jitter noise during routing. If not provided, the value is taken from tf.keras.backend.

Returns:
  Router indices or mask arrays (depending on the router type).
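The "masked matmul" in the class name refers to how mask arrays are consumed: tokens are gathered into expert slots by contracting a dispatch mask with the inputs, and expert outputs are scattered back by contracting a combine array with them. The sketch below handles one group in plain Python and assumes, without confirmation from this page, that both arrays have shape [tokens_per_group, num_experts, expert_capacity] per group and that the combine array holds the router probability at each selected slot. build_masks, dispatch_tokens, and combine_outputs are hypothetical names.

```python
from typing import List

Tensor3 = List[List[List[float]]]


def build_masks(router_probs: List[List[float]], expert_capacity: int):
    """Return (dispatch_mask, combine_array), each [tokens, experts, capacity]."""
    num_tokens, num_experts = len(router_probs), len(router_probs[0])

    def zeros() -> Tensor3:
        return [[[0.0] * expert_capacity for _ in range(num_experts)]
                for _ in range(num_tokens)]

    dispatch, combine = zeros(), zeros()
    for e in range(num_experts):
        # Expert e takes its top-`expert_capacity` tokens; slot c holds the c-th pick.
        ranked = sorted(range(num_tokens), key=lambda t: (-router_probs[t][e], t))
        for c, t in enumerate(ranked[:expert_capacity]):
            dispatch[t][e][c] = 1.0
            combine[t][e][c] = router_probs[t][e]
    return dispatch, combine


def dispatch_tokens(dispatch: Tensor3, inputs: List[List[float]]) -> Tensor3:
    """expert_inputs[e][c][h] = sum_t dispatch[t][e][c] * inputs[t][h]."""
    tokens, experts, capacity = len(dispatch), len(dispatch[0]), len(dispatch[0][0])
    hidden = len(inputs[0])
    return [[[sum(dispatch[t][e][c] * inputs[t][h] for t in range(tokens))
              for h in range(hidden)]
             for c in range(capacity)]
            for e in range(experts)]


def combine_outputs(combine: Tensor3, expert_outputs: Tensor3) -> List[List[float]]:
    """outputs[t][h] = sum over (e, c) of combine[t][e][c] * expert_outputs[e][c][h]."""
    tokens, experts, capacity = len(combine), len(combine[0]), len(combine[0][0])
    hidden = len(expert_outputs[0][0])
    return [[sum(combine[t][e][c] * expert_outputs[e][c][h]
                 for e in range(experts) for c in range(capacity))
             for h in range(hidden)]
            for t in range(tokens)]


# One group: 3 tokens, 2 experts, expert_capacity=1, hidden_dim=1.
router_probs = [[0.75, 0.25], [0.5, 0.5], [0.25, 0.75]]
inputs = [[1.0], [2.0], [3.0]]
dispatch, combine = build_masks(router_probs, expert_capacity=1)
expert_inputs = dispatch_tokens(dispatch, inputs)  # [[[1.0]], [[3.0]]]
# Toy experts: expert e multiplies its inputs by (e + 1).
expert_outputs = [[[x * (e + 1) for x in slot] for slot in expert_inputs[e]]
                  for e in range(len(expert_inputs))]
outputs = combine_outputs(combine, expert_outputs)
print(outputs)  # [[0.75], [0.0], [4.5]]
```

Token 1 is selected by no expert, so its combined output is zero. With batched tensors, the same two contractions could be written with einsum, for example 'gtec,gth->gech' for dispatch and 'gtec,gech->gth' for combine, under the shape assumption above.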