Llamadas personalizadas de XLA

En este documento, se describe cómo escribir y usar llamadas personalizadas de XLA. Las llamadas personalizadas te permiten invocar código escrito en un lenguaje de programación, como C++ o CUDA, desde un programa XLA.

Crea una llamada personalizada en la CPU

Puedes crear una instrucción de HLO que represente una llamada personalizada a través de la API de cliente de XLA. Por ejemplo, el siguiente código usa una llamada personalizada para calcular A[i] = B[i % 128]+ C[i] en la CPU. (Por supuesto que podrías, ¡y deberías! haz esto con HLO común).

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

Ten en cuenta que la función do_custom_call necesita conocer las dimensiones de los búferes sobre los que opera. En este ejemplo, codificamos los tamaños 128 y 2048. Si no quieres hacerlo, puedes pasar las dimensiones como parámetros a la llamada.

Crea una llamada personalizada en GPU

El framework de llamadas personalizadas de GPU es algo diferente al de la CPU. Este es un ejemplo de CUDA que realiza el mismo cálculo (A[i] = B[i % 128] + C[i]) que el código de CPU anterior.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

Primero, observa que la función de llamada personalizada de GPU sigue siendo una función ejecutada en la CPU. La función do_custom_call de la CPU se encarga de poner el trabajo en cola en la GPU. Aquí, inicia un kernel CUDA, pero también podría hacer otra acción, como llamar a cuBLAS.

buffers es un array de punteros que reside en el host, y cada elemento que contiene apunta a la memoria del dispositivo (es decir, GPU). Los parámetros van primero, seguidos del valor de salida. Esto difiere notablemente de la convención de llamada a la CPU, que tiene dos parámetros, ins y out. La convención de llamada de GPU permite administrar las entradas y salidas con forma de tupla de manera eficiente.

Al igual que en el ejemplo de la CPU, codificamos los tamaños del búfer de entrada y salida en nuestra llamada personalizada. Sin embargo, a diferencia de lo que ocurre en el caso de la CPU, pasar los tamaños de búfer como operandos a la llamada personalizada no funcionaría bien. Por lo general, necesitamos los tamaños de búfer disponibles en la CPU (p.ej., cuando se inicia un kernel, necesitamos conocer las dimensiones de bloque/cuadrícula que se deben usar). Sin embargo, si pasáramos los tamaños del búfer como operandos a nuestra llamada personalizada, sus valores permanecerían en la memoria de la GPU. Entonces, tendríamos que hacer un costoso memcpy síncrono de dispositivo a host al comienzo de nuestra operación solo para leer los tamaños.

Para que puedas solucionar este problema, proporcionamos el parámetro opaque. Puedes establecer esto como una string arbitraria de bytes cuando creas la llamada personalizada:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

Debido a que xla::Shape tiene una representación de búfer de protocolo, puedes almacenar este proto serializado dentro de opaque y deserializarlo dentro de tu llamada personalizada de GPU. Sin embargo, ten en cuenta que, aunque xla::ShapeProto no cambia con frecuencia, cambia. Consulta el registro de Git para ver cómo cambió en el pasado.

Cómo indicar un error

Si tu llamada personalizada encuentra un error, puedes indicar el error al entorno de ejecución de XLA (en lugar de, por ejemplo, una falla o un error sin sentido en los búferes de salida), mediante la siguiente firma para tu función:

En la CPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

en GPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

Puedes indicar la falla mediante XlaCustomCallStatusSetFailure, p.ej.:

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

También puedes usar XlaCustomCallStatusSetSuccess para indicar el éxito, pero XlaCustomCallStatus está en un estado correcto de forma predeterminada, por lo que ignorarlo por completo también indicará el éxito.

Cuando usas funciones de llamada personalizadas con esta firma, debes crear la op custom-call correspondiente con el conjunto de versiones de API adecuado, p.ej.:

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

Si se produce un error, no se usará ninguna de las salidas de llamadas personalizadas; el entorno de ejecución de XLA finalizará el cálculo. No es posible que un cálculo de HLO se recupere del error (p.ej., mediante su captura y su manejo).

Pasa tuplas a llamadas personalizadas

Considera la siguiente llamada personalizada.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

Tanto en la CPU como en la GPU, una tupla se representa en la memoria como un array de punteros. En el seudocódigo de C++, el parámetro 0 anterior se presenta de la siguiente manera.

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

Aunque la representación de tuplas en la memoria es la misma en la CPU y la GPU, se manejan de manera diferente en las convenciones de llamada de llamada personalizada de CPU y GPU.

Salidas de tuplas como búferes de temperatura

Las entradas de tuplas a las llamadas personalizadas son convenientes, pero no son estrictamente necesarias. Si no se admitirían entradas de tuplas para las llamadas personalizadas, siempre puedes descomprimir las tuplas con get-tuple-element antes de pasarlas a la llamada personalizada.

Por otro lado, los resultados de tuplas te permiten realizar acciones que de otra manera no podrías realizar.

El motivo obvio para tener salidas de tupla es que esa salida es la forma en que una llamada personalizada (o cualquier otra operación de XLA) muestra varios arreglos independientes.

Pero menos obviamente, el resultado de una tupla también es una forma de darle a tu memoria temporal de llamada personalizada. Sí, un resultado puede representar un búfer temporal. Considera que un búfer de salida tiene la propiedad que la operación puede escribir en él y puede leer después de que se escribe. Eso es exactamente lo que quieres de un búfer temporal.

En el ejemplo anterior, supongamos que queremos usar el F32[1024] como búfer temporal. Luego, se escribe el HLO igual que antes y nunca se lee el índice 1 de la tupla del resultado de la llamada personalizada.

Tuplas en llamadas personalizadas de CPU

En el código de la CPU, tenemos una función do_custom_call(const void** ins, void* out). ins es un array con un solo elemento, que apunta a param0. Se puede acceder a los subbúferes de param0 mediante la desreferencia de ese puntero, y se puede acceder a los subbúferes de output_tuple mediante la desreferencia de out.

Tuplas en llamadas personalizadas de GPU

En el código GPU, tenemos una función do_custom_call(..., void** buffers, ...). En este caso, buffers es un array de host de seis punteros de dispositivo, uno para cada búfer de hoja de la entrada/salida. Para generar una lista plana, iteramos los parámetros y el resultado y, para cada uno, hacemos un recorrido de pedido por adelantado de su forma. Concretamente:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1