Panggilan kustom XLA

Dokumen ini menjelaskan cara menulis dan menggunakan panggilan kustom XLA. Dengan panggilan kustom, Anda dapat memanggil kode yang ditulis dalam bahasa pemrograman seperti C++ atau CUDA dari program XLA.

Membuat panggilan kustom pada CPU

Anda dapat membuat petunjuk HLO yang merepresentasikan panggilan kustom melalui API klien XLA. Misalnya, kode berikut menggunakan panggilan kustom untuk menghitung A[i] = B[i % 128]+ C[i] pada CPU. (Tentu saja Anda bisa – dan harus! – lakukan ini dengan HLO reguler.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

Perhatikan bahwa fungsi do_custom_call perlu mengetahui dimensi buffer yang beroperasi. Dalam contoh ini, kami melakukan hardcode untuk ukuran 128 dan 2048. Jika tidak ingin melakukannya, Anda dapat meneruskan dimensi sebagai parameter ke panggilan.

Membuat panggilan kustom di GPU

Framework panggilan khusus GPU agak berbeda dari yang ada di CPU. Berikut adalah contoh CUDA yang melakukan komputasi yang sama (A[i] = B[i % 128] + C[i]) seperti kode CPU di atas.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

Perhatikan terlebih dahulu bahwa fungsi panggilan kustom GPU masih merupakan fungsi yang dijalankan di CPU. Fungsi CPU do_custom_call bertanggung jawab untuk mengantrekan pekerjaan di GPU. Di sini, kernel CUDA meluncurkan kernel CUDA, tetapi juga dapat melakukan hal lain, seperti memanggil cuBLAS.

buffers adalah array pointer yang ada di host, dan setiap elemen yang dimuatnya mengarah ke memori perangkat (yaitu GPU). Parameter akan muncul terlebih dahulu, lalu diikuti dengan nilai output. Hal ini sangat berbeda dari konvensi pemanggilan CPU, yang memiliki dua parameter, ins dan out. Konvensi pemanggilan GPU memungkinkan menangani input/output berbentuk tuple secara efisien.

Seperti dalam contoh CPU, kita telah meng-hardcode ukuran buffer input dan output ke dalam panggilan kustom. Namun, tidak seperti dalam kasus CPU, meneruskan ukuran buffer sebagai operand ke panggilan kustom tidak akan berfungsi dengan baik. Biasanya kita memerlukan ukuran buffer yang tersedia di CPU (misalnya saat meluncurkan kernel, kita perlu mengetahui dimensi blok/grid yang akan digunakan). Namun, jika kita meneruskan ukuran buffer sebagai operand ke panggilan khusus, nilainya akan tersimpan di memori GPU. Sehingga kita harus melakukan memcpy perangkat-ke-host sinkron yang mahal di awal operasi hanya untuk membaca ukuran.

Untuk memudahkan Anda mengatasi hal ini, kami menyediakan parameter opaque. Anda dapat menetapkannya ke string byte arbitrer saat membuat panggilan kustom:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

Karena xla::Shape memiliki representasi buffering protokol, Anda dapat menyimpan proto yang diserialisasi ini di dalam opaque dan melakukan deserialisasi dalam panggilan kustom GPU. Namun, perlu diketahui bahwa meskipun xla::ShapeProto tidak sering berubah, tetapi tidak berubah. Periksa log Git untuk melihat perubahannya di masa lalu.

Memberi sinyal error

Jika panggilan kustom mengalami error, Anda dapat memberikan sinyal error tersebut ke runtime XLA (bukannya error atau menampilkan obfuscation di buffer output) dengan menggunakan tanda tangan berikut untuk fungsi Anda:

Di CPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

di GPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

Anda dapat menandakan kegagalan menggunakan XlaCustomCallStatusSetFailure, misalnya:

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

Anda juga dapat menggunakan XlaCustomCallStatusSetSuccess untuk menunjukkan keberhasilan, tetapi XlaCustomCallStatus dalam status berhasil secara default, sehingga mengabaikannya sepenuhnya juga akan mengindikasikan keberhasilan.

Jika menggunakan fungsi panggilan kustom dengan tanda tangan ini, Anda harus membuat operasi custom-call yang sesuai dengan kumpulan versi API yang sesuai, misalnya:

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

Jika gagal, tidak ada output panggilan kustom yang akan digunakan; runtime XLA akan menghentikan komputasi. Komputasi HLO tidak dapat dipulihkan dari error (misalnya dengan menangkap dan menanganinya).

Meneruskan tuple ke panggilan kustom

Pertimbangkan panggilan kustom berikut.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

Di CPU dan GPU, tuple direpresentasikan dalam memori sebagai array pointer. Dalam pseudocode C++, parameter 0 di atas ditata sebagai berikut.

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

Meskipun representasi tuple dalam memori sama di CPU dan GPU, tule tersebut ditangani secara berbeda dalam konvensi panggilan panggilan khusus CPU dan GPU.

Output Tuple sebagai buffering sementara

Input tuple untuk panggilan kustom sangat praktis, tetapi tidak sepenuhnya diperlukan. Jika kami tidak mendukung input tuple ke panggilan kustom, Anda dapat mengekstrak tuple menggunakan get-tuple-element sebelum meneruskannya ke panggilan kustom.

Di sisi lain, output tuple memungkinkan Anda melakukan hal-hal yang tidak dapat Anda lakukan jika tidak.

Alasan pasti untuk memiliki output tuple adalah karena output tuple adalah bagaimana panggilan kustom (atau operasi XLA lainnya) menampilkan beberapa array independen.

Namun, output tuple juga merupakan cara untuk memberikan memori suhu panggilan kustom Anda. Ya, output dapat mewakili buffering sementara. Pertimbangkan, buffer output memiliki properti yang dapat ditulis oleh operasi, dan dapat dibaca darinya setelah ditulis. Itulah yang Anda inginkan dari buffering sementara.

Pada contoh di atas, misalkan kita ingin menggunakan F32[1024] sebagai buffering sementara. Kemudian kita akan menulis HLO seperti di atas, dan kita tidak akan pernah membaca tuple indeks 1 dari output panggilan kustom.

Tuple dalam panggilan kustom CPU

Dalam kode CPU, kita memiliki fungsi do_custom_call(const void** ins, void* out). ins adalah array yang hanya berisi satu elemen, yang menunjuk ke param0. Subbuffer param0 dapat diakses dengan mendereferensi pointer tersebut, dan subbuffer output_tuple dapat diakses dengan mendereferensi out.

Tuple dalam panggilan kustom GPU

Dalam kode GPU, kita memiliki fungsi do_custom_call(..., void** buffers, ...). Dalam hal ini, buffers adalah array host yang berisi enam pointer perangkat, satu untuk setiap leaf buffer dalam input/output. Untuk menghasilkan daftar datar, kami melakukan iterasi terhadap parameter dan output, dan untuk setiap parameter, kami melakukan traversal praorder sesuai bentuknya. Secara konkret:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1