XLA कस्टम कॉल

इस दस्तावेज़ में XLA कस्टम कॉल लिखने और इस्तेमाल करने का तरीका बताया गया है. कस्टम कॉल की मदद से, C++ या CUDA जैसी प्रोग्रामिंग भाषा में लिखे गए कोड को किसी XLA प्रोग्राम से इस्तेमाल किया जा सकता है.

सीपीयू पर पसंद के मुताबिक कॉल बनाएं

आपके पास एचएलओ निर्देश बनाने का विकल्प है जो XLA के क्लाइंट एपीआई के ज़रिए कस्टम कॉल दिखाता है. उदाहरण के लिए, यहां दिया गया कोड, सीपीयू पर A[i] = B[i % 128]+ C[i] को कंप्यूट करने के लिए कस्टम कॉल का इस्तेमाल करता है. (बेशक, आपको यह करना चाहिए - और करना चाहिए! – ऐसा नियमित एचएलओ के साथ करें.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

ध्यान दें कि do_custom_call फ़ंक्शन को उन बफ़र के डाइमेंशन के बारे में पता होना चाहिए जिन पर यह काम करता है. इस उदाहरण में, हमने 128 और 2048 साइज़ को हार्डकोड किया है. अगर आपको ऐसा नहीं करना है, तो कॉल में डाइमेंशन को पैरामीटर के तौर पर पास किया जा सकता है.

जीपीयू पर अपनी पसंद के मुताबिक कॉल करें

जीपीयू का कस्टम कॉल फ़्रेमवर्क, सीपीयू वाले फ़्रेमवर्क से कुछ अलग होता है. यहां सीयूडीए का एक उदाहरण दिया गया है, जो ऊपर दिए गए सीपीयू कोड की तरह ही कंप्यूटेशन (A[i] = B[i % 128] + C[i]) करता है.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

सबसे पहले ध्यान दें कि जीपीयू कस्टम कॉल फ़ंक्शन अब भी एक फ़ंक्शन है, जो सीपीयू पर चलाया जाता है. जीपीयू पर काम की सूची बनाने के लिए do_custom_call सीपीयू फ़ंक्शन ज़िम्मेदार है. यहां यह एक CUDA कर्नेल दिखाता है, लेकिन यह कुछ और काम भी कर सकता है, जैसे कि Call cuBLAS.

buffers, पॉइंटर का कलेक्शन है, जो होस्ट पर मौजूद होता है. साथ ही, इसमें मौजूद हर एलिमेंट में डिवाइस की जानकारी (यानी कि जीपीयू) मेमोरी होती है. पैरामीटर पहले आते हैं और उससे आउटपुट वैल्यू आती है. यह सीपीयू कॉलिंग कन्वेंशन से खास तौर पर अलग है, जिसमें दो पैरामीटर, ins और out होते हैं. जीपीयू कॉलिंग कन्वेंशन की वजह से, टपल के आकार के इनपुट/आउटपुट को बेहतर तरीके से मैनेज किया जा सकता है.

जैसा कि सीपीयू के उदाहरण में बताया गया है, हमने इनपुट और आउटपुट बफ़र साइज़ को अपने कस्टम कॉल में हार्डकोड किया है. हालांकि, सीपीयू (CPU) के मामले में, बफ़र साइज़ को कस्टम कॉल में ऑपरेटर के रूप में पास करना अच्छा काम नहीं करता. आम तौर पर, हमें सीपीयू पर उपलब्ध बफ़र साइज़ की ज़रूरत होती है. उदाहरण के लिए, कर्नेल को लॉन्च करते समय, हमें इस्तेमाल के लिए ब्लॉक/ग्रिड डाइमेंशन की जानकारी होनी चाहिए. हालांकि, अगर हम कस्टम कॉल में बफ़र साइज़ को ऑपरेटर के तौर पर पास करें, तो उनकी वैल्यू जीपीयू मेमोरी में सेव होंगी. इसके बाद हमें डिवाइस के साइज़ को देखने के लिए, अपने ऑपरेशन की शुरुआत में एक महंगा सिंक्रोनस डिवाइस-टू-होस्ट memcpy बनाना होगा.

आप इससे बेहतर काम कर सकें, इसके लिए हम opaque पैरामीटर देते हैं. कस्टम कॉल करते समय, इसे बाइट की आर्बिट्रेरी स्ट्रिंग पर सेट किया जा सकता है:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

xla::Shape में प्रोटोकॉल बफ़र दिखता है. इसलिए, इस प्रोटो को opaque के अंदर स्टोर किया जा सकता है. साथ ही, इसे अपने जीपीयू कस्टम कॉल में डीसीरियलाइज़ किया जा सकता है. हालांकि, ध्यान दें कि xla::ShapeProto बार-बार नहीं बदलता, लेकिन यह बदलता है. पहले में हुए बदलावों के बारे में जानने के लिए, Git लॉग देखें.

किसी गड़बड़ी का सिग्नल देना

अगर आपके कस्टम कॉल को कोई गड़बड़ी मिलती है, तो अपने फ़ंक्शन के लिए इस सिग्नेचर का इस्तेमाल करके, XLA रनटाइम (उदाहरण के लिए, आउटपुट बफ़र में क्रैश होना या कोई जानकारी न देना) को गड़बड़ी का सिग्नल भेजा जा सकता है:

सीपीयू पर:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

जीयूपी पर:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

XlaCustomCallStatusSetFailure का इस्तेमाल करके सिग्नल नहीं दिया जा सकता है, उदाहरण के लिए:

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

सफलता को दिखाने के लिए XlaCustomCallStatusSetSuccess का भी इस्तेमाल किया जा सकता है, लेकिन डिफ़ॉल्ट रूप से, XlaCustomCallStatus की स्थिति 'सफल स्थिति' में है. इसलिए, इसे पूरी तरह से अनदेखा करने का मतलब यह भी होगा कि आपको सफलता मिली है.

इस सिग्नेचर के साथ कस्टम कॉल फ़ंक्शन का इस्तेमाल करते समय, आपको एपीआई वर्शन के सही वर्शन के साथ उससे जुड़ा custom-call ऑपरेटर बनाना होगा, उदाहरण:

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

ऐसा न होने पर, किसी भी कस्टम कॉल आउटपुट का इस्तेमाल नहीं किया जाएगा. XLA रनटाइम, कंप्यूटेशन को खत्म कर देगा. एचएलओ कंप्यूटेशन के ज़रिए गड़बड़ी ठीक करना संभव नहीं है (उदाहरण के लिए, उसे पकड़ना और ठीक करना).

टूल को कस्टम कॉल में भेजना

नीचे दिए गए कस्टम कॉल को आज़माएं.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

सीपीयू और जीपीयू, दोनों में टपल को मेमोरी में पॉइंटर के ऐरे के तौर पर दिखाया जाता है. C++ pseudocode में, ऊपर दिए गए पैरामीटर 0 को इस तरह तय किया गया है.

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

सीपीयू और जीपीयू में, मेमोरी में टूल की परफ़ॉर्मेंस एक जैसी होती है. हालांकि, इन्हें सीपीयू और जीपीयू कस्टम-कॉल कॉल के तरीकों में अलग-अलग तरीके से मैनेज किया जाता है.

अस्थायी बफ़र के तौर पर टपल आउटपुट

कस्टम कॉल के लिए टपल इनपुट एक सुविधा है, लेकिन ये बहुत ज़रूरी नहीं हैं. अगर हम कस्टम कॉल के लिए टपल इनपुट की सुविधा नहीं देते हैं, तो कस्टम कॉल में पास करने से पहले गेट-टपल-एलिमेंट का इस्तेमाल करके टूल को अनपैक करें.

वहीं दूसरी ओर, टपल आउटपुट आपको वे काम करने देते हैं जो किसी दूसरे तरीके से नहीं किए जा सकते.

टपल आउटपुट होने की साफ़ वजह यह होती है कि टपल आउटपुट, इस तरह होता है कि एक कस्टम कॉल (या कोई भी दूसरा XLA op) टूल कई इंडिपेंडेंट अरे कैसे देता है.

टपल आउटपुट से भी कस्टम कॉल टेंपरेचर मेमोरी मिल सकती है. हां, आउटपुट अस्थायी बफ़र हो सकता है. देखें कि आउटपुट बफ़र में ऐसी प्रॉपर्टी है जिसे ऑपरेटर, उस कॉन्टेंट में लिख सकता है. साथ ही, आउटपुट होने के बाद, वह डेटा इस डेटा से पढ़ सकता है. अस्थायी बफ़र से आपको यही काम चाहिए.

ऊपर दिए गए उदाहरण में, मान लें कि हम F32[1024] को अस्थायी बफ़र के तौर पर इस्तेमाल करना चाहते हैं. फिर हम ऊपर की तरह एचएलओ लिखेंगे और कस्टम कॉल के आउटपुट का टपल इंडेक्स 1 कभी नहीं पढ़ा जाएगा.

सीपीयू (CPU) कस्टम कॉल में टपल

सीपीयू कोड में, हमारे पास एक do_custom_call(const void** ins, void* out) फ़ंक्शन है. ins सिर्फ़ एक एलिमेंट वाली कलेक्शन है, जो param0 के बारे में बताता है. param0 के सबफ़र को उस पॉइंटर का रेफ़रंस देकर ऐक्सेस किया जा सकता है. साथ ही, output_tuple के सबबर को out से हटाकर ऐक्सेस किया जा सकता है.

जीपीयू कस्टम कॉल में टपल

जीपीयू कोड में, एक do_custom_call(..., void** buffers, ...) फ़ंक्शन मौजूद है. इस मामले में, buffers, छह डिवाइस पॉइंटर की होस्ट कलेक्शन है. इनपुट/आउटपुट में लीफ़ के हर बफ़र के लिए एक होस्ट कलेक्शन है. फ़्लैट सूची जनरेट करने के लिए, हम पैरामीटर और आउटपुट की फिर से जांच करते हैं. साथ ही, हम हर एक के लिए अपने आकार का पहले से ऑर्डर करने का ट्रेवर्सल बनाते हैं. सटीक:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1