Este documento descreve como criar e usar chamadas personalizadas do XLA. As chamadas personalizadas permitem invocar o código escrito em uma linguagem de programação, como C++ ou CUDA, de um programa do XLA.
Criar uma chamada personalizada na CPU
É possível criar uma instrução HLO que representa uma chamada personalizada por meio da API cliente
do XLA. Por exemplo, o código a seguir usa uma chamada personalizada para calcular A[i] = B[i %
128]+ C[i]
na CPU. (É claro que você pode - e deve! – faça isso
com HLO normal.)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}
void do_custom_call(void* out, const void** in) {
float* out_buf = reinterpret_cast<float*>(out);
const float* in0 = reinterpret_cast<const float*>(in[0]);
const float* in1 = reinterpret_cast<const float*>(in[1]);
for (int i = 0; i < 2048; ++i) {
out_buf[i] = in0[i % 128] + in1[i];
}
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");
A função do_custom_call
precisa saber as dimensões dos
buffers em que opera. Neste exemplo, fixamos no código os tamanhos 128
e
2048
. Se não quiser fazer isso, transmita as dimensões como
parâmetros para a chamada.
Criar uma chamada personalizada na GPU
O framework de chamada personalizada da GPU é um pouco diferente do que está na CPU. Veja
um exemplo de CUDA que faz o mesmo cálculo (A[i] = B[i % 128] + C[i]
) que
o código de CPU acima.
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, void** buffers,
const char* opaque, size_t opaque_len) {
const float* in0 = reinterpret_cast<const float*>(buffers[0]);
const float* in1 = reinterpret_cast<const float*>(buffers[1]);
float* out = reinterpret_cast<float*>(buffers[2]);
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim,
/*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");
Primeiro, observe que a função de chamada personalizada da GPU ainda é uma função executada na CPU. A função da CPU do_custom_call
é responsável por enfileirar o trabalho
na GPU. Aqui, ele inicia um kernel CUDA, mas também pode fazer outra coisa,
como chamar cuBLAS.
buffers
é uma matriz de ponteiros que reside no host, e cada elemento nele
contém pontos para a memória do dispositivo (ou seja, GPU). Os parâmetros vêm primeiro, seguidos
pelo valor de saída. Isso é notavelmente diferente da convenção de chamada da CPU,
que tem dois parâmetros, ins
e out
. A convenção de chamada da GPU possibilita o processamento eficiente de entradas/saídas em forma de tupla.
Como no exemplo da CPU, os tamanhos dos buffers de entrada e saída foram fixados na
chamada personalizada. No entanto, ao contrário do caso da CPU, transmitir os tamanhos de buffer como
operadores para a chamada personalizada não funcionaria bem. Normalmente, precisamos dos tamanhos
de buffer disponíveis na CPU. Por exemplo, ao iniciar um kernel, precisamos saber
as dimensões de bloco/grade a serem usadas. No entanto, se transmitissemos os tamanhos de buffer como
operadores para nossa chamada personalizada, os valores deles ficariam na memória da GPU. Nesse caso,
precisaríamos fazer uma memcpy
síncrona de dispositivo para host de alto custo no início da nossa
operação apenas para ler os tamanhos.
Para contornar esse problema, fornecemos o parâmetro opaque
. Você pode defini-lo como uma string arbitrária de bytes ao criar a chamada personalizada:
std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
opaque);
Como xla::Shape
tem uma representação de buffer de protocolo, é possível armazenar esse
proto serializado dentro de opaque
e desserializá-lo na chamada personalizada
da GPU. No entanto, embora xla::ShapeProto
não mude com frequência,
ele muda. Verifique o registro do Git para ver as mudanças no passado.
Como sinalizar um erro
Se a chamada personalizada encontrar um erro, você poderá sinalizá-lo para o ambiente de execução do XLA (em vez de gerar falhas ou retornar erros nos buffers de saída) usando a seguinte assinatura para a função:
Na CPU:
#include "xla/service/custom_call_status.h"
void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);
Na GPU:
#include "xla/service/custom_call_status.h"
void do_custom_call(CUstream stream, void** buffers, const char* opaque,
size_t opaque_len, xla::XlaCustomCallStatus* status);
Você pode sinalizar falhas usando XlaCustomCallStatusSetFailure
, por exemplo:
void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
// ... do some work.
if (bad_condition) {
char* error_message = "An error occurred";
XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
return;
}
// ... continue.
}
Também é possível usar XlaCustomCallStatusSetSuccess
para indicar sucesso, mas o
XlaCustomCallStatus
está com um estado de sucesso por padrão. Portanto, ignorá-lo
completamente também indica sucesso.
Ao usar funções de chamada personalizadas com essa assinatura, é necessário criar a
operação custom-call
correspondente com o conjunto de versões da API apropriado, por exemplo:
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
opaque, /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/API_VERSION_STATUS_RETURNING);
Em caso de falha, nenhuma das saídas de chamada personalizadas será usada. O ambiente de execução do XLA encerrará o cálculo. Não é possível que um cálculo de HLO seja recuperado do erro (por exemplo, capturando e processando-o).
Como transmitir tuplas para chamadas personalizadas
Considere a chamada personalizada a seguir.
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
Na CPU e na GPU, uma tupla é representada na memória como uma matriz de ponteiros. No pseudocódigo C++, o parâmetro 0 acima é mostrado da seguinte maneira.
// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];
void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;
void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;
Embora a representação na memória de tuplas seja a mesma na CPU e na GPU, elas são tratadas de forma diferente nas convenções de chamada de chamada personalizada da CPU e da GPU.
Saídas de tuplas como buffers de temperatura
As entradas de tuplas para chamadas personalizadas são uma conveniência, mas não são estritamente necessárias. Caso as entradas de tupla não sejam compatíveis com chamadas personalizadas, você pode descompactar as tuplas usando get-tuple-element antes de transmiti-las para a chamada personalizada.
Por outro lado, com as outputs da tupla, é possível fazer coisas que não seria possível de outra forma.
A razão óbvia para ter saídas de tupla é como uma chamada personalizada (ou qualquer outra op do XLA) retorna várias matrizes independentes.
Mas, menos óbvio, uma saída de tupla também é uma maneira de fornecer memória temporária de chamada personalizada. Sim, uma saída pode representar um buffer temporário. Considere que um buffer de saída tem a propriedade em que a operação pode gravar e pode ler a partir dele depois de gravado. Isso é exatamente o que você quer de um buffer temporário.
No exemplo acima, suponha que você queira usar a F32[1024]
como um buffer temporário.
Em seguida, criaríamos o HLO como mostrado acima, e nunca leríamos o índice de tupla 1
da saída da chamada personalizada.
Tuplas em chamadas personalizadas da CPU
No código da CPU, temos uma função do_custom_call(const void** ins, void* out)
.
ins
é uma matriz com apenas um elemento, que aponta para param0
. Os
subbuffers de param0
são acessíveis desreferenciando esse ponteiro, e os
subbuffers de output_tuple
são acessíveis ao desreferenciar out
.
Tuplas em chamadas personalizadas da GPU
No código da GPU, temos uma função do_custom_call(..., void** buffers, ...)
. Nesse
caso, buffers
é uma matriz de host de seis ponteiros de dispositivo, um para cada buffer
de folha na entrada/saída. Para gerar a lista fixa, iteramos os parâmetros e a saída e, para cada um, fazemos uma travessia da pré-venda do formato.
Concretamente:
// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1