Appels personnalisés XLA

Ce document explique comment écrire et utiliser des appels personnalisés XLA. Les appels personnalisés vous permettent d'appeler du code écrit dans un langage de programmation tel que C++ ou CUDA à partir d'un programme XLA.

Créer un appel personnalisé sur le processeur

Vous pouvez créer une instruction HLO qui représente un appel personnalisé via l'API cliente de XLA. Par exemple, le code suivant utilise un appel personnalisé pour calculer A[i] = B[i % 128]+ C[i] sur le processeur. (Vous pourriez bien sûr ! utilisez l'appel HLO standard.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

Notez que la fonction do_custom_call doit connaître les dimensions des tampons sur lesquels elle opère. Dans cet exemple, nous codons en dur les tailles 128 et 2048. Si vous ne souhaitez pas le faire, vous pouvez transmettre les dimensions en tant que paramètres à l'appel.

Créer un appel personnalisé sur le GPU

Le framework d'appel personnalisé GPU est quelque peu différent de celui sur le processeur. Voici un exemple CUDA qui effectue le même calcul (A[i] = B[i % 128] + C[i]) que le code du processeur ci-dessus.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

Notez d'abord que la fonction d'appel personnalisé GPU est toujours une fonction exécutée sur le processeur. La fonction CPU do_custom_call est responsable de la mise en file d'attente des tâches sur le GPU. Ici, il lance un noyau CUDA, mais il pourrait également faire autre chose, comme appeler cuBLAS.

buffers est un tableau de pointeurs résidant sur l'hôte et chaque élément qu'il contient pointe vers la mémoire de l'appareil (c'est-à-dire le GPU). Les paramètres sont placés en premier, suivis de la valeur de sortie. Cela diffère particulièrement de la convention d'appel du processeur, qui comporte deux paramètres : ins et out. La convention d'appel GPU permet de gérer efficacement les entrées/sorties de type tuple.

Comme dans l'exemple du processeur, nous avons codé en dur les tailles des tampons d'entrée et de sortie dans notre appel personnalisé. Toutefois, contrairement au cas du processeur, la transmission des tailles de mémoire tampon sous forme d'opérandes à l'appel personnalisé ne fonctionne pas correctement. Nous avons généralement besoin des tailles de tampon disponibles sur le processeur (par exemple, lors du lancement d'un noyau, nous devons connaître les dimensions de bloc/grille à utiliser). Toutefois, si nous transmettions les tailles de mémoire tampon en tant qu'opérandes à notre appel personnalisé, leurs valeurs seraient stockées dans la mémoire du GPU. Il faudrait alors exécuter un memcpy synchrone d'appareil à hôte coûteux au début de l'opération juste pour lire les tailles.

Pour vous permettre d'y remédier, nous fournissons le paramètre opaque. Vous pouvez le définir sur une chaîne arbitraire d'octets lorsque vous créez l'appel personnalisé:

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

Comme xla::Shape dispose d'une représentation de tampon de protocole, vous pouvez stocker ce proto sérialisé dans opaque et le désérialiser dans votre appel personnalisé GPU. Notez toutefois que même si xla::ShapeProto ne change pas fréquemment, il change. Consultez le journal Git pour voir comment il a changé dans le passé.

Signalement d'une erreur

Si votre appel personnalisé rencontre une erreur, vous pouvez la signaler à l'environnement d'exécution XLA (au lieu de planter ou renvoyer un message absurde dans les tampons de sortie, par exemple) en utilisant la signature suivante pour votre fonction:

Sur le processeur:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

Sur GPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

Vous pouvez signaler une défaillance à l'aide de XlaCustomCallStatusSetFailure, par exemple :

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

Vous pouvez également utiliser XlaCustomCallStatusSetSuccess pour indiquer la réussite de l'opération, mais XlaCustomCallStatus est en état de réussite par défaut. Si vous l'ignorez complètement, cela signifie donc également que l'opération a réussi.

Lorsque vous utilisez des fonctions d'appel personnalisées avec cette signature, vous devez créer l'opération custom-call correspondante avec la version d'API appropriée, par exemple :

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

En cas d'échec, aucune des sorties d'appel personnalisées n'est utilisée. L'environnement d'exécution XLA met fin au calcul. Un calcul HLO ne peut pas récupérer de l'erreur (par exemple, en l'interceptant et en la gérant).

Transmettre des tuples aux appels personnalisés

Prenons l'appel personnalisé suivant.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

Sur le processeur et le GPU, un tuple est représenté dans la mémoire sous la forme d'un tableau de pointeurs. Dans le pseudo-code C++, le paramètre 0 ci-dessus est présenté comme suit.

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

Bien que la représentation en mémoire des tuples soit la même pour le processeur et le GPU, ils sont traités différemment dans les conventions d'appel personnalisé du processeur et du GPU.

Sorties du tuple en tant que tampons temporaires

Les entrées tuples des appels personnalisés sont pratiques, mais elles ne sont pas strictement nécessaires. Si nous ne prenons pas en charge les entrées de tuple pour les appels personnalisés, vous pouvez toujours décompresser les tuples à l'aide de get-tuple-element avant de les transmettre à l'appel personnalisé.

En revanche, les tuples outputs vous permettent d'effectuer des opérations impossibles à effectuer autrement.

La raison évidente d'avoir des sorties de tuples est que les sorties de tuples correspondent à la manière dont un appel personnalisé (ou toute autre opération XLA) renvoie plusieurs tableaux indépendants.

Mais, de manière moins évidente, une sortie tuple est également un moyen de fournir une mémoire temporaire à votre appel personnalisé. Oui, une sortie peut représenter un tampon temporaire. Considérons qu'un tampon de sortie possède la propriété que l'opération peut y écrire, et qu'elle peut lire après y avoir été écrite. C'est exactement ce que vous attendez d'un tampon temporaire.

Dans l'exemple ci-dessus, supposons que nous voulions utiliser F32[1024] comme tampon temporaire. Ensuite, nous écrivons le HLO comme ci-dessus, et nous ne lisons jamais l'index tuple 1 de la sortie de l'appel personnalisé.

Tuples dans les appels de processeur personnalisés

Dans le code du processeur, nous avons une fonction do_custom_call(const void** ins, void* out). ins est un tableau comportant un seul élément, qui pointe vers param0. Les sous-tampons de param0 sont accessibles en déréférenceant ce pointeur, et les sous-tampons de output_tuple sont accessibles en déréférenceant out.

Tuples dans les appels personnalisés GPU

Dans le code GPU, nous avons une fonction do_custom_call(..., void** buffers, ...). Dans ce cas, buffers est un tableau hôte de six pointeurs d'appareil, un pour chaque tampon feuille dans l'entrée/sortie. Pour générer la liste plate, nous parcourons les paramètres et la sortie, et pour chacun d'eux, nous effectuons un balayage en précommande de sa forme. Concrètement:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1