XLA: Como otimizar o compilador para machine learning

O XLA (álgebra linear acelerada, na sigla em inglês) é um compilador específico de domínio para álgebra linear que pode acelerar modelos do TensorFlow sem necessidade de mudanças no código-fonte.

Isso resulta em melhorias de velocidade e uso de memória: por exemplo, o envio de BERT MLPerf usando GPUs 8 Volta V100 com XLA teve um desempenho 7 vezes melhor e um tamanho do lote 5 vezes melhor.

Introdução

Quando um programa do TensorFlow é executado, todas as operações são executadas individualmente pelo executor do TensorFlow. Cada operation do TensorFlow tem uma implementação de kernel da GPU pré-compilada para onde o executor envia.

O XLA fornece um modo alternativo de execução de modelos: ele compila o grafo do TensorFlow em uma sequência de kernels de computação gerados especificamente para o modelo fornecido. Como esses kernels são exclusivos do modelo, eles podem explorar informações específicas do modelo para otimização. Por exemplo, vamos analisar o XLA de otimização no contexto de um cálculo simples do TensorFlow:

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

Executado sem o XLA, o grafo inicia três kernels: um para a multiplicação, outro para a adição e outro para a redução. No entanto, o XLA pode otimizar o grafo para que ele calcule o resultado em uma única inicialização do kernel. Ele faz isso "combinando" adição, multiplicação e redução em um único kernel da GPU. Além disso, essa operation combinada não grava os valores intermediários produzidos por y*z e x+y*z na memória. Em vez disso, ela "transmite" os resultados desses cálculos intermediários diretamente para os usuários, mantendo-os inteiramente em registros da GPU. A fusão é a otimização mais importante do XLA. A largura de banda de memória normalmente é o recurso mais escasso em aceleradores de hardware. Portanto, remover operações de memória é uma das melhores maneiras de melhorar o desempenho.

Ativar XLA para modelos do TensorFlow

Compilação explícita com tf.function(jit_compile=True)

A API de compilação explícita oferece um controle refinado para escolher quais funções devem ser compiladas. Por exemplo, a seguinte função do TensorFlow que realiza o treinamento MNIST é compilada com o XLA:

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

A API jit_compile tem semântica must-compile: a função inteira é compilada com XLA ou uma exceção errors.InvalidArgumentError é gerada. No momento, o XLA não pode compilar funções em que as dimensões não são inferíveis: isto é, se não for possível inferir as dimensões de todos os tensores sem executar o cálculo completo. Por exemplo, a função a seguir não pode ser compilada:

@tf.function
def not_compilable(x):
  return tf.unique(x)

No entanto, as formas podem variar entre as execuções:

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

Consulte o tutorial no Colab para ver um exemplo de uso mais detalhado e um tutorial em vídeo sobre o uso de jit_compile=True.

Clustering automático

Uma forma simples de começar a usar XLA em modelos do TensorFlow sem nenhuma mudança é ativando o clustering automático, que encontra automaticamente clusters (subgrafos conectados) nas funções do TensorFlow que podem ser compiladas e executadas usando XLA. O clustering automático na GPU pode ser ativado ao definir a variável de ambiente TF_XLA_FLAGS:

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

Atualmente, o clustering automático é otimizado para cargas de trabalho da GPU, mas também pode ser ativado na CPU usando a sinalização --tf_xla_cpu_global_jit adicionalmente:

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

Para ver um exemplo de uso detalhado, consulte o tutorial sobre clustering automático no Colab.

Compilação Ahead-of-time (AOT) para CPU com tfcompile

Também é possível usar uma ferramenta autônoma tfcompile, que converte o grafo do TensorFlow em código executável (somente para CPU x86-64).

Inspecionar programas compilados

O XLA fornece instalações de introspecção que permitem inspecionar os programas gerados. Para despejar os programas gerados, use a variável de ambiente XLA_FLAGS:

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

Depois que o despejo for realizado, você poderá encontrar os arquivos a seguir em /tmp/generated:

Também é possível descarregar o grafo que visualiza a incorporação de clusters XLA dentro do grafo do TensorFlow com:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

Relatórios de bugs reproduzíveis

É muito mais fácil reproduzir um relatório de bug se ele inclui dumps para os programas XLA gerados e a incorporação de clustering automático usada. Para gerá-los para um programa do TensorFlow executado com clustering automático, execute:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

Ao enviar relatórios de bugs, anexe o conteúdo do diretório /tmp/generated (mencionado acima).

Se possível, tente isolar um bug em um único programa XLA usando replay_computation e executando-o iterativamente nos programas gerados.

Leitura adicional

Front-ends de XLA

Além do TensorFlow, os programas XLA podem ser gerados por:

  • JAX: transformações combináveis de programas Python e NumPy;
  • Julia: a linguagem Julia para computação científica;
  • PyTorch: framework PyTorch
  • Nx: biblioteca de computação numérica para a linguagem de programação Elixir

Palestras

Como usar XLA do TF com jit_compile=True

Visão geral de XLA