Chamadas personalizadas do XLA

Este documento descreve como criar e usar chamadas personalizadas do XLA. As chamadas personalizadas permitem invocar o código escrito em uma linguagem de programação, como C++ ou CUDA, de um programa do XLA.

Criar uma chamada personalizada na CPU

É possível criar uma instrução HLO que representa uma chamada personalizada por meio da API cliente do XLA. Por exemplo, o código a seguir usa uma chamada personalizada para calcular A[i] = B[i % 128]+ C[i] na CPU. (É claro que você pode - e deve! – faça isso com HLO normal.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

A função do_custom_call precisa saber as dimensões dos buffers em que opera. Neste exemplo, fixamos no código os tamanhos 128 e 2048. Se não quiser fazer isso, transmita as dimensões como parâmetros para a chamada.

Criar uma chamada personalizada na GPU

O framework de chamada personalizada da GPU é um pouco diferente do que está na CPU. Veja um exemplo de CUDA que faz o mesmo cálculo (A[i] = B[i % 128] + C[i]) que o código de CPU acima.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

Primeiro, observe que a função de chamada personalizada da GPU ainda é uma função executada na CPU. A função da CPU do_custom_call é responsável por enfileirar o trabalho na GPU. Aqui, ele inicia um kernel CUDA, mas também pode fazer outra coisa, como chamar cuBLAS.

buffers é uma matriz de ponteiros que reside no host, e cada elemento nele contém pontos para a memória do dispositivo (ou seja, GPU). Os parâmetros vêm primeiro, seguidos pelo valor de saída. Isso é notavelmente diferente da convenção de chamada da CPU, que tem dois parâmetros, ins e out. A convenção de chamada da GPU possibilita o processamento eficiente de entradas/saídas em forma de tupla.

Como no exemplo da CPU, os tamanhos dos buffers de entrada e saída foram fixados na chamada personalizada. No entanto, ao contrário do caso da CPU, transmitir os tamanhos de buffer como operadores para a chamada personalizada não funcionaria bem. Normalmente, precisamos dos tamanhos de buffer disponíveis na CPU. Por exemplo, ao iniciar um kernel, precisamos saber as dimensões de bloco/grade a serem usadas. No entanto, se transmitissemos os tamanhos de buffer como operadores para nossa chamada personalizada, os valores deles ficariam na memória da GPU. Nesse caso, precisaríamos fazer uma memcpy síncrona de dispositivo para host de alto custo no início da nossa operação apenas para ler os tamanhos.

Para contornar esse problema, fornecemos o parâmetro opaque. Você pode defini-lo como uma string arbitrária de bytes ao criar a chamada personalizada:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

Como xla::Shape tem uma representação de buffer de protocolo, é possível armazenar esse proto serializado dentro de opaque e desserializá-lo na chamada personalizada da GPU. No entanto, embora xla::ShapeProto não mude com frequência, ele muda. Verifique o registro do Git para ver as mudanças no passado.

Como sinalizar um erro

Se a chamada personalizada encontrar um erro, você poderá sinalizá-lo para o ambiente de execução do XLA (em vez de gerar falhas ou retornar erros nos buffers de saída) usando a seguinte assinatura para a função:

Na CPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

Na GPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

Você pode sinalizar falhas usando XlaCustomCallStatusSetFailure, por exemplo:

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

Também é possível usar XlaCustomCallStatusSetSuccess para indicar sucesso, mas o XlaCustomCallStatus está com um estado de sucesso por padrão. Portanto, ignorá-lo completamente também indica sucesso.

Ao usar funções de chamada personalizadas com essa assinatura, é necessário criar a operação custom-call correspondente com o conjunto de versões da API apropriado, por exemplo:

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

Em caso de falha, nenhuma das saídas de chamada personalizadas será usada. O ambiente de execução do XLA encerrará o cálculo. Não é possível que um cálculo de HLO seja recuperado do erro (por exemplo, capturando e processando-o).

Como transmitir tuplas para chamadas personalizadas

Considere a chamada personalizada a seguir.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

Na CPU e na GPU, uma tupla é representada na memória como uma matriz de ponteiros. No pseudocódigo C++, o parâmetro 0 acima é mostrado da seguinte maneira.

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

Embora a representação na memória de tuplas seja a mesma na CPU e na GPU, elas são tratadas de forma diferente nas convenções de chamada de chamada personalizada da CPU e da GPU.

Saídas de tuplas como buffers de temperatura

As entradas de tuplas para chamadas personalizadas são uma conveniência, mas não são estritamente necessárias. Caso as entradas de tupla não sejam compatíveis com chamadas personalizadas, você pode descompactar as tuplas usando get-tuple-element antes de transmiti-las para a chamada personalizada.

Por outro lado, com as outputs da tupla, é possível fazer coisas que não seria possível de outra forma.

A razão óbvia para ter saídas de tupla é como uma chamada personalizada (ou qualquer outra op do XLA) retorna várias matrizes independentes.

Mas, menos óbvio, uma saída de tupla também é uma maneira de fornecer memória temporária de chamada personalizada. Sim, uma saída pode representar um buffer temporário. Considere que um buffer de saída tem a propriedade em que a operação pode gravar e pode ler a partir dele depois de gravado. Isso é exatamente o que você quer de um buffer temporário.

No exemplo acima, suponha que você queira usar a F32[1024] como um buffer temporário. Em seguida, criaríamos o HLO como mostrado acima, e nunca leríamos o índice de tupla 1 da saída da chamada personalizada.

Tuplas em chamadas personalizadas da CPU

No código da CPU, temos uma função do_custom_call(const void** ins, void* out). ins é uma matriz com apenas um elemento, que aponta para param0. Os subbuffers de param0 são acessíveis desreferenciando esse ponteiro, e os subbuffers de output_tuple são acessíveis ao desreferenciar out.

Tuplas em chamadas personalizadas da GPU

No código da GPU, temos uma função do_custom_call(..., void** buffers, ...). Nesse caso, buffers é uma matriz de host de seis ponteiros de dispositivo, um para cada buffer de folha na entrada/saída. Para gerar a lista fixa, iteramos os parâmetros e a saída e, para cada um, fazemos uma travessia da pré-venda do formato. Concretamente:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1