Chiamate personalizzate XLA

Questo documento descrive come scrivere e utilizzare le chiamate personalizzate XLA. Le chiamate personalizzate consentono di richiamare codice scritto in un linguaggio di programmazione come C++ o CUDA da un programma XLA.

Crea una chiamata personalizzata sulla CPU

Puoi creare un'istruzione HLO che rappresenta una chiamata personalizzata tramite l'API client di XLA. Ad esempio, il codice seguente utilizza una chiamata personalizzata per calcolare A[i] = B[i % 128]+ C[i] sulla CPU. (Ovviamente potresti e dovresti! (puoi farlo con un normale HLO).

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

Nota che la funzione do_custom_call deve conoscere le dimensioni dei buffer su cui opera. In questo esempio, le dimensioni 128 e 2048 sono impostate come hardcoded. Se non vuoi farlo, puoi passare le dimensioni come parametri alla chiamata.

Crea una chiamata personalizzata su GPU

Il framework delle chiamate personalizzate della GPU è leggermente diverso da quello della CPU. Ecco un esempio CUDA che esegue lo stesso calcolo (A[i] = B[i % 128] + C[i]) del codice della CPU riportato sopra.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

Tieni presente che la funzione di chiamata personalizzata GPU è ancora una funzione eseguita sulla CPU. La funzione della CPU do_custom_call è responsabile dell'accodamento del lavoro sulla GPU. Qui viene avviato un kernel CUDA, ma potrebbe anche fare qualcos'altro, ad esempio chiamare cuBLAS.

buffers è un array di puntatori che si trovano sull'host e ogni elemento che contiene rimanda alla memoria del dispositivo (ovvero la GPU). I parametri vengono per primi, seguiti dal valore di output. È notevolmente diversa dalla convenzione di chiamata della CPU, che ha due parametri, ins e out. La convenzione di chiamata GPU consente di gestire in modo efficiente input/output a forma di tupla.

Come nell'esempio della CPU, abbiamo inserito come hardcoded le dimensioni del buffer di input e di output nella chiamata personalizzata. Tuttavia, a differenza del caso della CPU, il passaggio delle dimensioni del buffer come operatori alla chiamata personalizzata non funzionerebbe bene. Di solito abbiamo bisogno delle dimensioni del buffer disponibili sulla CPU (ad esempio, quando si avvia un kernel, dobbiamo conoscere le dimensioni del blocco/grid da utilizzare). Ma se passassimo le dimensioni del buffer come operandi alla nostra chiamata personalizzata, i loro valori riterrebbero nella memoria GPU. Dovremmo quindi eseguire un costoso memcpy dispositivo-host sincrono all'inizio dell'operazione solo per leggere le dimensioni.

Per aiutarti a risolvere il problema, forniamo il parametro opaque. Puoi impostarlo su una stringa arbitraria di byte quando crei la chiamata personalizzata:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

Poiché xla::Shape ha una rappresentazione di buffer di protocollo, puoi archiviare questo proto in serie all'interno di opaque e deserializzarlo all'interno della tua chiamata GPU personalizzata. Tuttavia, tieni presente che, sebbene xla::ShapeProto non cambi spesso, cambia. Controlla il log Git per vedere come è cambiato in passato.

Segnalazione di un errore

Se si verifica un errore durante la chiamata personalizzata, puoi segnalare l'errore al runtime XLA (anziché, ad esempio, causare un arresto anomalo o restituire sciocchezze nei buffer di output) utilizzando la seguente firma per la tua funzione:

Su CPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

sulla GPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

Puoi segnalare errori utilizzando XlaCustomCallStatusSetFailure, ad esempio:

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

Puoi anche utilizzare XlaCustomCallStatusSetSuccess per indicare l'esito positivo, ma lo stato di XlaCustomCallStatus per impostazione predefinita è riuscito, quindi anche ignorarlo completamente indica che l'operazione è riuscita.

Quando utilizzi funzioni di chiamata personalizzate con questa firma, devi creare l'operazione custom-call corrispondente con la versione dell'API appropriata, ad esempio:

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

In caso di errore, non verrà utilizzato nessuno degli output di chiamate personalizzate; il runtime XLA terminerà il calcolo. Un calcolo HLO non può recuperare dall'errore (ad esempio rilevandolo e gestendolo).

Passaggio delle tuple alle chiamate personalizzate

Prendi in considerazione la seguente chiamata personalizzata.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

Sia su CPU che su GPU, una tupla è rappresentata in memoria come un array di puntatori. Nello pseudocodice C++, il parametro 0 riportato sopra è strutturato come segue.

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

Sebbene la rappresentazione in memoria delle tuple sia la stessa in CPU e GPU, vengono gestite in modo diverso nelle convenzioni di chiamata personalizzata per CPU e GPU.

La tuple esce come buffer temporanei

Gli input di tuple per le chiamate personalizzate sono comodi, ma non sono strettamente necessari. Se non supportavamo gli input tuple per le chiamate personalizzate, puoi sempre decomprimere le tuple utilizzando get-tuple-element prima di passarle alla chiamata personalizzata.

D'altra parte, gli output a tuple ti consentono di eseguire operazioni che altrimenti non avresti potuto eseguire.

L'ovvio motivo per avere output a tuple è che gli output a tuple sono il modo in cui una chiamata personalizzata (o qualsiasi altra operazione XLA) restituisce più array indipendenti.

Ma meno, ovviamente, l'output a tuple è anche un modo per assegnare la memoria temporanea delle chiamate. Sì, un output può rappresentare un buffer temporaneo. Tieni presente che un buffer di output ha la proprietà che l'operatore può scrivere e può leggerlo dopo la scrittura. Questo è esattamente ciò che vuoi da un buffer di temperatura.

Nell'esempio precedente, supponiamo di voler utilizzare F32[1024] come buffer temporaneo. Quindi scriveremmo l'HLO come indicato sopra e non avremmo mai letto l'indice tuple 1 dell'output della chiamata personalizzata.

Tuple nelle chiamate personalizzate della CPU

Nel codice della CPU, abbiamo una funzione do_custom_call(const void** ins, void* out). ins è un array con un solo elemento che rimanda a param0. I subbuffer di param0 sono accessibili dereferenziando il puntatore, mentre i subbuffer di output_tuple sono accessibili dereferendo out.

Tuple nelle chiamate personalizzate GPU

Nel codice GPU, abbiamo una funzione do_custom_call(..., void** buffers, ...). In questo caso buffers è un array host di sei puntatori dispositivo, uno per ogni buffer foglia nell'input/output. Per generare l'elenco semplice, eseguiamo l'iterazione dei parametri e dell'output e per ognuno eseguiamo un attraversamento di pre-ordine della sua forma. Concretamente:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1