XLA özel aramaları

Bu dokümanda, XLA özel çağrılarının nasıl yazılacağı ve kullanılacağı açıklanmaktadır. Özel çağrılar, bir XLA programından C++ veya CUDA gibi bir programlama dilinde yazılmış kodu çağırmanıza olanak tanır.

CPU'da özel bir çağrı oluştur

XLA'nın istemci API'si aracılığıyla özel bir çağrıyı temsil eden bir HLO talimatı oluşturabilirsiniz. Örneğin, aşağıdaki kodda CPU üzerinde A[i] = B[i % 128]+ C[i] hesaplaması için özel bir çağrı kullanılmaktadır. (Elbette bunu yapabilirsiniz ve yapmalısınız. bunu normal HLO ile yapın.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

do_custom_call işlevinin üzerinde çalıştığı arabelleklerin boyutlarını bilmesi gerektiğine dikkat edin. Bu örnekte 128 ve 2048 boyutlarını kodluyoruz. Bunu yapmak istemiyorsanız boyutları çağrıya parametreler olarak aktarabilirsiniz.

GPU'da özel çağrı oluştur

GPU özel çağrı çerçevesi, CPU'dakinden biraz farklıdır. Yukarıdaki CPU koduyla aynı hesaplamayı (A[i] = B[i % 128] + C[i]) yapan bir CUDA örneğini burada bulabilirsiniz.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

Öncelikle GPU özel çağrı işlevinin hala CPU'da yürütülen bir işlev olduğuna dikkat edin. do_custom_call CPU işlevi, GPU'daki işleri sıraya koymaktan sorumludur. Burada bir CUDA çekirdeğini başlatır, ancak cuBLAS'ı çağırma gibi başka bir işlemi de yapabilir.

buffers, ana makinede bulunan işaretçiler dizisidir ve içerdiği her öğe, cihaz (ör. GPU) belleğine puan içerir. Parametreler önce gelir, ardından çıktı değeri gelir. Bu, iki parametreye sahip olan CPU çağırma kuralından oldukça farklıdır: ins ve out. GPU çağırma kuralı, çift şeklindeki girişlerin/çıkışların verimli bir şekilde işlenmesini sağlar.

CPU örneğinde olduğu gibi, giriş ve çıkış arabelleği boyutlarını özel çağrımıza kod olarak yerleştirdik. Ancak CPU örneğinden farklı olarak, arabellek boyutlarını özel çağrıya işlenenler olarak geçirmek iyi sonuç vermez. Genellikle CPU'da kullanabildiğimiz arabellek boyutlarına ihtiyacımız vardır (ör. bir çekirdeği kullanıma sunarken, kullanılacak blok/ızgara boyutlarını bilmemiz gerekir). Ancak arabellek boyutlarını özel çağrımıza işlenenler olarak geçirirsek değerleri GPU belleğinde kalırdı. Bu durumda sadece boyutları okumak için operasyonun başında pahalı bir eşzamanlı cihazdan ana makineye memcpy işlemi uygulamamız gerekirdi.

Bu sorunu çözebilmeniz için opaque parametresini sunuyoruz. Özel çağrıyı oluştururken bunu rastgele bir bayt dizesine ayarlayabilirsiniz:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

xla::Shape, bir protokol arabelleği gösterimine sahip olduğundan bu serileştirilmiş protokolü opaque içinde depolayabilir ve GPU özel çağrınızda seri durumdan çıkarabilirsiniz. Ancak xla::ShapeProto sık sık değişmese de değişir. Geçmişte nasıl değiştiğini görmek için Git günlüğünü kontrol edin.

Hata sinyali

Özel çağrınız bir hatayla karşılaşırsa işleviniz için aşağıdaki imzayı kullanarak hatayı XLA çalışma zamanına bildirebilirsiniz (ör. kilitlenme veya çıkış arabelleklerinde anlamsız sonuçlar döndürmek yerine):

CPU'da:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

GPU'da:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

XlaCustomCallStatusSetFailure kullanarak hata sinyalini verebilirsiniz. Örneğin:

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

Başarıyı belirtmek için XlaCustomCallStatusSetSuccess öğesini de kullanabilirsiniz, ancak XlaCustomCallStatus varsayılan olarak başarılı durumundadır, bu nedenle tamamen göz ardı edilmesi de başarılı olduğu anlamına gelir.

Bu imzayla özel çağrı işlevlerini kullanırken karşılık gelen custom-call işlemini uygun API sürümüyle oluşturmanız gerekir. Örneğin:

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

Hata durumunda, özel çağrı çıkışlarının hiçbiri kullanılmaz. XLA çalışma zamanı hesaplamayı sonlandırır. HLO hesaplamasının hatadan kurtulması (ör. yakalayıp işleyerek) mümkün değildir.

Özel çağrılara tuple aktarma

Aşağıdaki özel çağrıyı değerlendirebilirsiniz.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

Hem CPU hem de GPU'da bir Tuple, bellekte işaretçi dizisi olarak temsil edilir. C++ sözde kodda, yukarıdaki parametre 0 aşağıdaki gibi yerleştirilir.

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

Tuple'ların bellek içi gösterimi CPU ve GPU'da aynı olsa da CPU ve GPU özel çağrı çağırma kurallarında farklı şekilde ele alınırlar.

Tuple, geçici arabellek olarak çıkış yapar

Özel çağrılara çift giriş yapmak kolaylık sağlar, ancak kesinlikle gerekli değildir. Özel çağrılara tuple girişi desteklememiş olsaydık bunları özel çağrıya iletmeden önce get-tuple-öğesini kullanarak her zaman paketinden çıkarabiliyordunuz.

Diğer yandan, çift çıkışlar başka şekilde yapamayacağınız şeyleri yapmanıza olanak tanır.

Tuple çıkışlarının olmasının en bariz nedeni, tuple çıkışlarının özel bir çağrının (veya başka herhangi bir XLA operatörünün) birden çok bağımsız dizi döndürme şekli olmasıdır.

Ancak açıkça belirtmek gerekirse Tuple çıkışı, özel çağrı geçici belleğinizi vermenin de bir yoludur. Evet, çıkış geçici bir arabelleği temsil edebilir. Çıktı arabelleğinin, işlemin üzerine yazabileceği özelliğe sahip olduğunu ve yazıldıktan sonra buradan okuyabileceğini düşünün. Sıcaklık arabelleğinden de tam olarak bunu istiyorsun.

Yukarıdaki örnekte, F32[1024] öğesini geçici arabellek olarak kullanmak istediğimizi varsayalım. Daha sonra HLO'yu yukarıdaki gibi yazardık ve özel çağrı çıkışının 1. çift dizinini hiçbir zaman okumayız.

CPU özel çağrılarındaki Tuple sayısı

CPU kodunda do_custom_call(const void** ins, void* out) adlı bir işlev var. ins, tek bir öğe içeren ve param0 değerini gösteren bir dizidir. Bu işaretçiye başvurulmadan param0 alt tamponlarına erişilebilir ve output_tuple alt arabelleklerine, out referansının alınmasıyla erişilebilir.

GPU özel çağrılarındaki Tuple sayısı

GPU kodunda bir do_custom_call(..., void** buffers, ...) işlevimiz vardır. Bu örnekte buffers, giriş/çıkıştaki her yaprak arabelleği için bir adet olmak üzere altı cihaz işaretçisinden oluşan bir ana makine dizisidir. Düz liste oluşturmak için parametreler ve çıktılar üzerinde yinelenir ve her biri için şeklin ön siparişi geçişini yaparız. Somut olarak:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1