Пользовательские звонки XLA

В этом документе описывается, как писать и использовать пользовательские вызовы XLA. Пользовательские вызовы позволяют вызывать из программы XLA код, написанный на языке программирования, таком как C++ или CUDA.

Создайте собственный вызов на ЦП

Вы можете создать инструкцию HLO, которая представляет собой пользовательский вызов через клиентский API XLA. Например, следующий код использует специальный вызов для вычисления A[i] = B[i % 128]+ C[i] на ЦП. (Конечно, вы можете – и должны! – сделать это с помощью обычного HLO.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

Обратите внимание, что функции do_custom_call необходимо знать размеры буферов, над которыми она работает. В этом примере мы жестко запрограммировали размеры 128 и 2048 . Если вы не хотите этого делать, вы можете передать размеры в качестве параметров вызова.

Создайте собственный вызов на графическом процессоре

Платформа пользовательских вызовов графического процессора несколько отличается от платформы центрального процессора. Вот пример CUDA, который выполняет те же вычисления ( A[i] = B[i % 128] + C[i] ), что и приведенный выше код ЦП.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

Прежде всего обратите внимание, что пользовательская функция вызова графического процессора по-прежнему является функцией, выполняемой на ЦП . Функция CPU do_custom_call отвечает за постановку в очередь работы на графическом процессоре. Здесь он запускает ядро ​​CUDA, но может делать и что-то еще, например вызывать cuBLAS.

buffers — это массив указателей, который находится на хосте, и каждый содержащийся в нем элемент указывает на память устройства (т. е. графического процессора). Сначала идут параметры, а затем выходное значение. Это заметно отличается от соглашения о вызовах ЦП, которое имеет два параметра: ins и out . Соглашение о вызовах графического процессора позволяет эффективно обрабатывать вводы/выводы в форме кортежей.

Как и в примере с процессором, мы жестко запрограммировали размеры входного и выходного буфера в нашем пользовательском вызове. Однако, в отличие от случая с ЦП, передача размеров буфера в качестве операндов пользовательскому вызову не будет работать должным образом. Обычно нам нужны размеры буфера, доступные нам в ЦП (например, при запуске ядра нам нужно знать размеры блока/сетки, которые мы будем использовать). Но если бы мы передали размеры буфера в качестве операндов нашему пользовательскому вызову, их значения остались бы в памяти графического процессора. Тогда нам пришлось бы выполнить дорогостоящую синхронную memcpy между устройством и хостом в начале нашей операции только для того, чтобы прочитать размеры.

Чтобы обойти эту проблему, мы предоставляем параметр opaque . Вы можете установить для него произвольную строку байтов при создании пользовательского вызова:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

Поскольку xla::Shape имеет представление буфера протокола, вы можете сохранить этот сериализованный прототип внутри opaque и десериализовать его в пользовательском вызове графического процессора. Однако обратите внимание: хотя xla::ShapeProto не меняется часто, он меняется . Проверьте журнал Git, чтобы увидеть, как он изменился в прошлом.

Сигнализация об ошибке

Если ваш пользовательский вызов обнаруживает ошибку, вы можете сообщить об ошибке среде выполнения XLA (вместо, например, сбоя или возврата бессмысленной информации в выходных буферах), используя следующую сигнатуру для вашей функции:

На процессоре:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

на графическом процессоре:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

Вы можете сигнализировать об ошибке, используя XlaCustomCallStatusSetFailure , например:

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

Вы также можете использовать XlaCustomCallStatusSetSuccess для обозначения успеха, но XlaCustomCallStatus по умолчанию находится в состоянии успеха, поэтому полное его игнорирование также будет означать успех.

При использовании пользовательских функций вызова с этой сигнатурой необходимо создать соответствующую операцию custom-call с соответствующим набором версий API, например:

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

В случае сбоя ни один из пользовательских выходов вызова не будет использоваться; среда выполнения XLA прекратит вычисления. Вычисление HLO невозможно исправить ошибку (например, путем ее обнаружения и обработки).

Передача кортежей в пользовательские вызовы

Рассмотрим следующий пользовательский вызов.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

Как на процессоре, так и на графическом процессоре кортеж представлен в памяти как массив указателей. В псевдокоде C++ параметр 0, указанный выше, выглядит следующим образом.

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

Хотя представление кортежей в памяти одинаково в ЦП и ГП, они обрабатываются по-разному в соглашениях о вызовах пользовательских вызовов ЦП и ГП.

Кортеж выводится как временные буферы

Ввод кортежей в пользовательские вызовы удобен, но не является строго необходимым. Если бы мы не поддерживали ввод кортежей для пользовательских вызовов, вы всегда могли бы распаковать кортежи с помощью get-tuple-element перед передачей их в пользовательский вызов.

С другой стороны, выходные данные кортежа позволяют вам делать то, что иначе вы бы не смогли.

Очевидная причина использования выходных данных кортежа заключается в том, что выходные данные кортежа — это то, как пользовательский вызов (или любая другая операция XLA) возвращает несколько независимых массивов.

Но менее очевидно, что вывод кортежа также является способом выделения временной памяти для вашего пользовательского вызова. Да, выходные данные могут представлять собой временный буфер. Учтите, что выходной буфер имеет свойство, позволяющее оператору писать в него и читать из него после того, как он был записан. Это именно то, что вы хотите от временного буфера.

Предположим, в приведенном выше примере мы хотим использовать F32[1024] в качестве временного буфера. Тогда мы бы написали HLO так же, как указано выше, и просто никогда не читали бы кортеж с индексом 1 вывода пользовательского вызова.

Кортежи в пользовательских вызовах ЦП

В коде ЦП у нас есть функция do_custom_call(const void** ins, void* out) . ins — это массив всего с одним элементом, указывающим на param0 . Подбуферы параметра param0 доступны путем разыменования этого указателя, а подбуферы output_tuple доступны путем разыменования out .

Кортежи в пользовательских вызовах графического процессора

В коде графического процессора у нас есть функция do_custom_call(..., void** buffers, ...) . В этом случае buffers представляют собой хост-массив из шести указателей устройств, по одному для каждого листового буфера во входе/выходе. Чтобы сгенерировать плоский список, мы перебираем параметры и выходные данные и для каждого выполняем предварительный обход его формы. Конкретно:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1