XLA 自訂呼叫

本文件說明如何編寫及使用 XLA 自訂呼叫。自訂呼叫可讓您透過 XLA 程式叫用以 C++ 或 CUDA 等程式設計語言編寫的程式碼。

在 CPU 上建立自訂呼叫

您可以透過 XLA 的用戶端 API 建立 HLO 指令來表示自訂呼叫。例如,以下程式碼使用自訂呼叫,在 CPU 上計算 A[i] = B[i % 128]+ C[i]。(當然可以,而且應該這麼做!- 使用一般 HLO 執行此功能)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

請注意,do_custom_call 函式需要知道其運作的緩衝區尺寸。在這個範例中,我們對大小 1282048 進行硬式編碼。如果不想執行這項作業,您可以將維度做為參數傳入呼叫。

在 GPU 上建立自訂呼叫

GPU 自訂呼叫架構與 CPU 上略有不同。以下 CUDA 範例執行的運算與上述 CPU 程式碼相同 (A[i] = B[i % 128] + C[i])。

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

請注意,GPU 自訂呼叫函式仍是在 CPU 上執行的函式do_custom_call CPU 函式負責將 GPU 上的工作排入佇列。此工具會啟動 CUDA 核心,但也可以執行其他操作,例如呼叫 cuBLAS。

buffers 是位於主機上的指標陣列,其中包含指向裝置 (即 GPU) 記憶體的每個元素。參數會先依序列出,接著是輸出值。這與 CPU 呼叫慣例不同,後者有兩個參數:insout。GPU 呼叫慣例可讓您有效率地處理元組的輸入/輸出內容。

與 CPU 範例一樣,我們已將輸入和輸出緩衝區空間硬式編碼為自訂呼叫。然而,與 CPU 情況不同的是,將緩衝區大小做為運算元傳遞至自訂呼叫將無法正常運作。我們通常會需要 CPU 上的緩衝區大小 (例如,啟動核心時,必須知道要使用的區塊/格線維度)。不過,如果我們將緩衝區大小做為運算元傳遞給自訂呼叫,這些緩衝區值會保存在 GPU 記憶體中。接著,我們需要在作業開始時執行昂貴的同步裝置對代管 memcpy,以便讀取大小。

為協助您解決這個問題,我們提供 opaque 參數。您可以在建立自訂呼叫時,將此設為任意位元組字串:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

由於 xla::Shape 有通訊協定緩衝區表示法,因此您可以將這個序列化的 proto 儲存在 opaque 中,並在 GPU 自訂呼叫內去序列化。不過請注意,雖然 xla::ShapeProto 不會經常變更,但「確實」會變更。請查看 Git 記錄,瞭解過去的異動。

指出錯誤

如果自訂呼叫遇到錯誤,您可以在函式中使用下列簽名,將錯誤信號傳送至 XLA 執行階段 (而非在輸出緩衝區中異常終止或傳回不合理):

在 CPU 上:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

使用 GPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

您可以使用 XlaCustomCallStatusSetFailure 表示失敗,例如:

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

您也可以使用 XlaCustomCallStatusSetSuccess 表示成功,但 XlaCustomCallStatus 預設會處於成功狀態,因此完全忽略也可以表示成功。

搭配這個簽章使用自訂呼叫函式時,您必須建立適當的 custom-call 運算,並提供適當的 API 版本集,例如:

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

失敗時,系統不會使用任何自訂呼叫輸出;XLA 執行階段將終止計算。HLO 運算無法從錯誤中復原 (例如擷取及處理錯誤)。

將元組傳送至自訂呼叫

請考慮使用下列自訂呼叫。

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

在 CPU 和 GPU 上,元組會以指標陣列的形式在記憶體中表示。在 C++ 虛擬程式碼中,上方的參數 0 呈現如下。

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

雖然 CPU 和 GPU 在記憶體內配置的表示法相同,但 CPU 和 GPU 自訂呼叫呼叫慣例的處理方式不同。

以臨時緩衝區形式輸出的元組輸出內容

自訂呼叫的元組輸入很方便,但其實沒有必要性。如果我們不支援自訂呼叫元組輸入內容,您可以先使用 get-tuple-element 將元組解壓縮,再將其傳送至自訂呼叫。

另一方面,元組的「輸出」可讓您執行原本無法執行的操作。

出現元組輸出的明顯原因是,元組輸出是自訂呼叫 (或任何其他 XLA 運算) 傳回多個獨立陣列的方式。

但較不明顯,元組輸出也是提供自訂呼叫暫存記憶體的方法。可以,「輸出」可以代表暫存緩衝區。請考慮到輸出緩衝區,其中含有運算可寫入的屬性,且可在寫入後從輸出緩衝區讀取。這就是您需要的臨時緩衝區。

在上述範例中,假設我們想使用 F32[1024] 做為暫存緩衝區。接下來,我們會如上所示寫入 HLO,而我們絕不會讀取自訂呼叫輸出內容的元組索引 1。

CPU 自訂呼叫中的元組

在 CPU 程式碼中,我們有 do_custom_call(const void** ins, void* out) 函式。ins 是一個陣列,只有一個元素,而該元素指向 param0。可解開該指標即可存取 param0 的子緩衝區,而宣告 out 可存取 output_tuple 的子緩衝區。

GPU 自訂呼叫中的元組

在 GPU 程式碼中,我們有 do_custom_call(..., void** buffers, ...) 函式。在此情況下,buffers6 裝置指標的主機陣列,其中每個分葉緩衝區在輸入/輸出中各有一個。為產生平面清單,我們會反覆處理參數和輸出內容,並且針對每個參數進行預購。具體做法:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1