Cette page a été traduite par l'API Cloud Translation.
Switch to English

Créer une opération

Si vous souhaitez créer une opération qui n'est pas couverte par la bibliothèque TensorFlow existante, nous vous recommandons d'essayer d'abord d'écrire l'opération en Python en tant que composition d'opérations ou de fonctions Python existantes. Si ce n'est pas possible, vous pouvez créer un op C ++ personnalisé. Il y a plusieurs raisons pour lesquelles vous voudrez peut-être créer un op C ++ personnalisé:

  • Il n'est ni facile ni possible d'exprimer votre opération comme une composition d'opérations existantes.
  • Il n'est pas efficace d'exprimer votre opération comme une composition de primitives existantes.
  • Vous voulez fusionner manuellement une composition de primitives qu'un futur compilateur trouverait difficile de fusionner.

Par exemple, imaginons que vous souhaitiez implémenter quelque chose comme «pooling médian», similaire à l'opérateur «MaxPool», mais en calculant les médianes sur des fenêtres glissantes au lieu des valeurs maximales. Faire cela en utilisant une composition d'opérations peut être possible (par exemple, en utilisant ExtractImagePatches et TopK), mais peut ne pas être aussi efficace en termes de performances ou de mémoire qu'une opération native où vous pouvez faire quelque chose de plus intelligent en une seule opération fusionnée. Comme toujours, il vaut la peine d'essayer d'abord d'exprimer ce que vous voulez à l'aide de la composition d'opérateurs, en choisissant d'ajouter une nouvelle opération uniquement si cela s'avère difficile ou inefficace.

Pour intégrer votre opération personnalisée, vous devrez:

  1. Enregistrez la nouvelle opération dans un fichier C ++. L'enregistrement d'opération définit une interface (spécification) pour la fonctionnalité de l'op, qui est indépendante de l'implémentation de l'op. Par exemple, l'enregistrement op définit le nom de l'op et les entrées et sorties de l'op. Il définit également la fonction de forme utilisée pour l'inférence de forme tenseur.
  2. Implémentez l'op en C ++. L'implémentation d'un op est connue sous le nom de noyau, et c'est l'implémentation concrète de la spécification que vous avez enregistrée à l'étape 1. Il peut y avoir plusieurs noyaux pour différents types ou architectures d'entrée / sortie (par exemple, CPU, GPU).
  3. Créez un wrapper Python (facultatif). Ce wrapper est l'API publique utilisée pour créer l'opération en Python. Un wrapper par défaut est généré à partir de l'enregistrement d'opérations, qui peut être utilisé directement ou ajouté à.
  4. Ecrivez une fonction pour calculer les dégradés de l'op (facultatif).
  5. Testez l'op. Nous faisons généralement cela en Python pour plus de commodité, mais vous pouvez également tester l'opération en C ++. Si vous définissez des dégradés, vous pouvez les vérifier avec Python tf.test.compute_gradient_error . Voir relu_op_test.py comme exemple qui teste les fonctions avancées des opérateurs de type Relu et leurs gradients.

Conditions préalables

Définir l'interface op

Vous définissez l'interface d'une opération en l'enregistrant auprès du système TensorFlow. Dans l'enregistrement, vous spécifiez le nom de votre opération, ses entrées (types et noms) et sorties (types et noms), ainsi que les docstrings et tous les attrs que l'opération pourrait nécessiter.

Pour voir comment cela fonctionne, supposons que vous souhaitiez créer une opération qui prend un tenseur de int32 s et génère une copie du tenseur, avec tout sauf le premier élément mis à zéro. Pour ce faire, créez un fichier nommé zero_out.cc . Ensuite, ajoutez un appel à la macro REGISTER_OP qui définit l'interface de votre opération:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

Cette opération ZeroOut prend un tenseur à to_zero d'entiers 32 bits en entrée et génère un tenseur zeroed d'entiers 32 bits. L'op utilise également une fonction de forme pour s'assurer que le tenseur de sortie a la même forme que le tenseur d'entrée. Par exemple, si l'entrée est un tenseur de forme [10, 20], alors cette fonction de forme spécifie que la forme de sortie est également [10, 20].

Implémenter le noyau pour l'op

Après avoir défini l'interface, fournissez une ou plusieurs implémentations de l'op. Pour créer l'un de ces noyaux, créez une classe qui étend OpKernel et remplace la méthode Compute . La méthode Compute fournit un argument de context de type OpKernelContext* , à partir duquel vous pouvez accéder à des éléments utiles comme les tenseurs d'entrée et de sortie.

Ajoutez votre noyau au fichier que vous avez créé ci-dessus. Le noyau pourrait ressembler à ceci:

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};

Après avoir implémenté votre noyau, vous l'enregistrez auprès du système TensorFlow. Dans l'enregistrement, vous spécifiez différentes contraintes sous lesquelles ce noyau s'exécutera. Par exemple, vous pouvez avoir un noyau conçu pour les processeurs et un autre pour les GPU.

Pour ce faire pour l'opération ZeroOut , ajoutez ce qui suit à zero_out.cc :

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

Noyaux CPU multi-threads

Pour écrire un noyau de processeur multithread, la fonction Shard dans work_sharder.h peut être utilisée. Cette fonction partage une fonction de calcul entre les threads configurés pour être utilisés pour le thread intra-op (voir intra_op_parallelism_threads dans config.proto ).

Noyaux GPU

Un noyau GPU est implémenté en deux parties: le noyau OpKernel et le noyau CUDA et son code de lancement.

Parfois, l'implémentation d'OpKernel est commune entre un noyau CPU et GPU, par exemple autour de l'inspection des entrées et de l'allocation des sorties. Dans ce cas, une implémentation suggérée consiste à:

  1. Définissez le modèle OpKernel sur le périphérique et le type primitif du tenseur.
  2. Pour effectuer le calcul réel de la sortie, la fonction Compute appelle une structure de foncteur basée sur un modèle.
  3. La spécialisation de ce foncteur pour le CPUDevice est définie dans le même fichier, mais la spécialisation pour le GPUDevice est définie dans un fichier .cu.cc, car il sera compilé avec le compilateur CUDA.

Voici un exemple de mise en œuvre.

// kernel_example.h
#ifndef KERNEL_EXAMPLE_H_
#define KERNEL_EXAMPLE_H_

template <typename Device, typename T>
struct ExampleFunctor {
  void operator()(const Device& d, int size, const T* in, T* out);
};

#if GOOGLE_CUDA
// Partially specialize functor for GpuDevice.
template <typename T>
struct ExampleFunctor<Eigen::GpuDevice, T> {
  void operator()(const Eigen::GpuDevice& d, int size, const T* in, T* out);
};
#endif

#endif KERNEL_EXAMPLE_H_
// kernel_example.cc
#include "kernel_example.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

// CPU specialization of actual computation.
template <typename T>
struct ExampleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, int size, const T* in, T* out) {
    for (int i = 0; i < size; ++i) {
      out[i] = 2 * in[i];
    }
  }
};

// OpKernel definition.
// template parameter <T> is the datatype of the tensors.
template <typename Device, typename T>
class ExampleOp : public OpKernel {
 public:
  explicit ExampleOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));

    // Do the computation.
    OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max,
                errors::InvalidArgument("Too many elements in tensor"));
    ExampleFunctor<Device, T>()(
        context->eigen_device<Device>(),
        static_cast<int>(input_tensor.NumElements()),
        input_tensor.flat<T>().data(),
        output_tensor->flat<T>().data());
  }
};

// Register the CPU kernels.
#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      ExampleOp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(int32);

// Register the GPU kernels.
#ifdef GOOGLE_CUDA
#define REGISTER_GPU(T)                                          \
  /* Declare explicit instantiations in kernel_example.cu.cc. */ \
  extern template class ExampleFunctor<GPUDevice, T>;            \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      ExampleOp<GPUDevice, T>);
REGISTER_GPU(float);
REGISTER_GPU(int32);
#endif  // GOOGLE_CUDA
// kernel_example.cu.cc
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "example.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

using namespace tensorflow;

using GPUDevice = Eigen::GpuDevice;

// Define the CUDA kernel.
template <typename T>
__global__ void ExampleCudaKernel(const int size, const T* in, T* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += blockDim.x * gridDim.x) {
    out[i] = 2 * __ldg(in + i);
  }
}

// Define the GPU implementation that launches the CUDA kernel.
template <typename T>
void ExampleFunctor<GPUDevice, T>::operator()(
    const GPUDevice& d, int size, const T* in, T* out) {
  // Launch the cuda kernel.
  //
  // See core/util/gpu_kernel_helper.h for example of computing
  // block count and thread_per_block count.
  int block_count = 1024;
  int thread_per_block = 20;
  ExampleCudaKernel<T>
      <<<block_count, thread_per_block, 0, d.stream()>>>(size, in, out);
}

// Explicitly instantiate functors for the types of OpKernels registered.
template struct ExampleFunctor<GPUDevice, float>;
template struct ExampleFunctor<GPUDevice, int32>;

#endif  // GOOGLE_CUDA

Construire la bibliothèque op

Compilez l'opération en utilisant votre compilateur système (installation binaire TensorFlow)

Vous devriez être capable de compiler zero_out.cc avec un compilateur C++ tel que g++ ou clang disponible sur votre système. Le package PIP binaire installe les fichiers d'en-tête et la bibliothèque dont vous avez besoin pour compiler votre opération dans des emplacements spécifiques au système. Cependant, la bibliothèque python TensorFlow fournit la fonction get_include pour obtenir le répertoire d'en-tête, et le répertoire get_lib a un objet partagé à lier. Voici les sorties de ces fonctions sur une machine Ubuntu.

$ python
>>> import tensorflow as tf
>>> tf.sysconfig.get_include()
'/usr/local/lib/python3.6/site-packages/tensorflow/include'
>>> tf.sysconfig.get_lib()
'/usr/local/lib/python3.6/site-packages/tensorflow'

En supposant que g++ installé, voici la séquence de commandes que vous pouvez utiliser pour compiler votre opération dans une bibliothèque dynamique.

TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2

Sous macOS, l'indicateur supplémentaire "-undefined dynamic_lookup" est requis lors de la création du fichier .so .

Note sur la version gcc >=5 : gcc utilise le nouvel ABI C ++ depuis la version 5 . Les packages de pip binaires disponibles sur le site Web de TensorFlow sont construits avec gcc4 qui utilise l'ancienne ABI. Si vous compilez votre bibliothèque op avec gcc>=5 , ajoutez -D_GLIBCXX_USE_CXX11_ABI=0 à la ligne de commande pour rendre la bibliothèque compatible avec l'ancien abi.

Compilez l'opération à l'aide de bazel (installation de la source TensorFlow)

Si vous avez installé des sources TensorFlow, vous pouvez utiliser le système de construction de TensorFlow pour compiler votre opération. Placez un fichier BUILD avec la règle de construction Bazel suivante dans le tensorflow/core/user_ops .

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    name = "zero_out.so",
    srcs = ["zero_out.cc"],
)

Exécutez la commande suivante pour générer zero_out.so .

$ bazel build --config opt //tensorflow/core/user_ops:zero_out.so

Utilisez l'op en Python

L'API TensorFlow Python fournit la fonction tf.load_op_library pour charger la bibliothèque dynamique et enregistrer l'opération avec le framework TensorFlow. load_op_library retourne un module Python qui contient les wrappers Python pour l'op et le noyau. Ainsi, une fois que vous avez construit l'op, vous pouvez faire ce qui suit pour l'exécuter à partir de Python:

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
with tf.Session(''):
  zero_out_module.zero_out([[1, 2], [3, 4]]).eval()

# Prints
array([[1, 0], [0, 0]], dtype=int32)

Gardez à l'esprit que la fonction générée recevra un nom snake_case (pour se conformer à PEP8 ). Donc, si votre op est nommé ZeroOut dans les fichiers C ++, la fonction python s'appellera zero_out .

Pour l'op disponible en fonction régulière import able d'un module Python, il peut être utile d'avoir le load_op_library appel dans un fichier source Python comme suit:

import tensorflow as tf

zero_out_module = tf.load_op_library('./zero_out.so')
zero_out = zero_out_module.zero_out

Vérifiez que l'opération fonctionne

Un bon moyen de vérifier que vous avez implémenté avec succès votre opération est d'écrire un test pour cela. Créez le fichier zero_out_op_test.py avec le contenu:

import tensorflow as tf

class ZeroOutTest(tf.test.TestCase):
  def testZeroOut(self):
    zero_out_module = tf.load_op_library('./zero_out.so')
    with self.test_session():
      result = zero_out_module.zero_out([5, 4, 3, 2, 1])
      self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])

if __name__ == "__main__":
  tf.test.main()

Ensuite, exécutez votre test (en supposant que tensorflow est installé):

$ python zero_out_op_test.py

Intégrez des fonctionnalités avancées dans votre opération

Maintenant que vous savez comment construire une opération et une implémentation basiques (et quelque peu restreintes), nous allons examiner certaines des choses les plus compliquées dont vous aurez généralement besoin pour intégrer votre opération. Ceci comprend:

Contrôles et validation conditionnels

L'exemple ci-dessus supposait que l'op s'appliquait à un tenseur de n'importe quelle forme. Et si cela ne s'appliquait qu'aux vecteurs? Cela signifie ajouter une vérification à l'implémentation OpKernel ci-dessus.

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
                errors::InvalidArgument("ZeroOut expects a 1-D vector."));
    // ...
  }

Cela affirme que l'entrée est un vecteur et retourne avoir défini le statut InvalidArgument si ce n'est pas le cas. La macro OP_REQUIRES prend trois arguments:

Sinon, si vous voulez tester si un objet Status renvoyé par une fonction est une erreur et si c'est le cas, renvoyez-le, utilisez OP_REQUIRES_OK . Ces deux macros reviennent de la fonction en cas d'erreur.

Inscription aux opérations

Attrs

Les opérations peuvent avoir des attrs, dont les valeurs sont définies lorsque l'opération est ajoutée à un graphique. Ceux-ci sont utilisés pour configurer l'opération, et leurs valeurs sont accessibles à la fois dans l'implémentation du noyau et dans les types d'entrées et de sorties dans l'enregistrement d'opérations. Préférez utiliser une entrée plutôt qu'un attr lorsque cela est possible, car les entrées sont plus flexibles. En effet, les attrs sont des constantes et doivent être définies au moment de la construction du graphe. En revanche, les entrées sont des Tensors dont les valeurs peuvent être dynamiques; c'est-à-dire que les entrées peuvent changer à chaque étape, être définies à l'aide d'un flux, etc. t changer d'une étape à l'autre.

Vous définissez un attr lorsque vous enregistrez l'opération, en spécifiant son nom et son type à l'aide de la méthode Attr , qui attend une spécification de la forme:

<name>: <attr-type-expr>

<name> commence par une lettre et peut être composé de caractères alphanumériques et de traits de soulignement, et <attr-type-expr> est une expression de type de la forme décrite ci-dessous .

Par exemple, si vous souhaitez que l'opération ZeroOut conserve un index spécifié par l'utilisateur, au lieu du seul élément 0, vous pouvez enregistrer l'opération comme suit:

REGISTER_OP("ZeroOut")
    .Attr("preserve_index: int")
    .Input("to_zero: int32")
    .Output("zeroed: int32");

(Notez que l'ensemble des types d'attributs est différent dutf.DType utilisé pour les entrées et les sorties.)

Votre noyau peut alors accéder à cet attr dans son constructeur via le paramètre context :

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
    // Get the index of the value to preserve
    OP_REQUIRES_OK(context,
                   context->GetAttr("preserve_index", &preserve_index_));
    // Check that preserve_index is positive
    OP_REQUIRES(context, preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ",
                                        preserve_index_));
  }
  void Compute(OpKernelContext* context) override {
    // ...
  }
 private:
  int preserve_index_;
};

qui peut ensuite être utilisé dans la méthode Compute :

  void Compute(OpKernelContext* context) override {
    // ...

    // We're using saved attr to validate potentially dynamic input
    // So we check that preserve_index is in range
    OP_REQUIRES(context, preserve_index_ < input.dimension(0),
                errors::InvalidArgument("preserve_index out of range"));

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the requested input value
    output_flat(preserve_index_) = input(preserve_index_);
  }

Attr types

Les types suivants sont pris en charge dans un attr:

  • string : toute séquence d'octets (non requise pour être UTF8).
  • int : un entier signé.
  • float : un nombre à virgule flottante.
  • bool : vrai ou faux.
  • type : une des valeurs (non ref) de DataType .
  • shape : Un TensorShapeProto .
  • list(<type>) : Une liste de <type> , où <type> est l'un des types ci-dessus. Notez que list(list(<type>)) n'est pas valide.

Voir aussi: op_def_builder.cc:FinalizeAttr pour une liste définitive.

Valeurs par défaut et contraintes

Les Attrs peuvent avoir des valeurs par défaut, et certains types d'attrs peuvent avoir des contraintes. Pour définir un attr avec des contraintes, vous pouvez utiliser les <attr-type-expr> s suivants:

{'<string1>', '<string2>'} : la valeur doit être une chaîne qui a la valeur <string1> ou <string2> . Le nom du type, string , est implicite lorsque vous utilisez cette syntaxe. Cela émule une énumération:

REGISTER_OP("EnumExample")
    .Attr("e: {'apple', 'orange'}");

{<type1>, <type2>} : la valeur est de type type et doit être l'une des <type2> <type1> ou <type2> , où <type1> et <type2> sont pris en charge tf.DType . Vous ne spécifiez pas que le type de l'attr est type . Ceci est implicite lorsque vous avez une liste de types dans {...} . Par exemple, dans ce cas, l'attr t est un type qui doit être un int32 , un float ou un bool :

REGISTER_OP("RestrictedTypeExample")
    .Attr("t: {int32, float, bool}");

Il existe des raccourcis pour les contraintes de type courantes:

  • numbertype : type de type limité aux type numériques (non-chaîne et non booléens).
  • realnumbertype : comme numbertype sans types complexes.
  • quantizedtype : comme numbertype mais uniquement les types de nombres quantifiés.

Les listes spécifiques des types autorisés par ceux-ci sont définies par les fonctions (comme NumberTypes() ) dans tensorflow/core/framework/types.h . Dans cet exemple, l'attr t doit être l'un des types numériques:

REGISTER_OP("NumberType")
    .Attr("t: numbertype");

Pour cette opération:

tf.number_type(t=tf.int32)  # Valid
tf.number_type(t=tf.bool)   # Invalid

Les listes peuvent être combinées avec d'autres listes et types uniques. L'opération suivante permet à attr t d'être l'un des types numériques ou le type booléen:

REGISTER_OP("NumberOrBooleanType")
    .Attr("t: {numbertype, bool}");

Pour cette opération:

tf.number_or_boolean_type(t=tf.int32)  # Valid
tf.number_or_boolean_type(t=tf.bool)   # Valid
tf.number_or_boolean_type(t=tf.string) # Invalid

int >= <n> : la valeur doit être un entier dont la valeur est supérieure ou égale à <n> , où <n> est un nombre naturel. Par exemple, l'enregistrement d'opération suivant spécifie que l'attr a doit avoir une valeur d'au moins 2 :

REGISTER_OP("MinIntExample")
    .Attr("a: int >= 2");

list(<type>) >= <n> : Une liste de type <type> dont la longueur est supérieure ou égale à <n> . Par exemple, l'enregistrement d'opération suivant spécifie que l'attr a est une liste de types ( int32 ou float ), et qu'il doit y en avoir au moins 3:

REGISTER_OP("TypeListExample")
    .Attr("a: list({int32, float}) >= 3");

Pour définir une valeur par défaut pour un attr (le rendant facultatif dans le code généré), ajoutez = <default> à la fin, comme dans:

REGISTER_OP("AttrDefaultExample")
    .Attr("i: int = 0");

De plus, une contrainte et une valeur par défaut peuvent être spécifiées:

REGISTER_OP("AttrConstraintAndDefaultExample")
    .Attr("i: int >= 1 = 1");

La syntaxe prise en charge de la valeur par défaut est celle qui serait utilisée dans la représentation proto de la définition GraphDef résultante.

Voici des exemples pour savoir comment spécifier une valeur par défaut pour tous les types:

REGISTER_OP("AttrDefaultExampleForAllTypes")
   .Attr("s: string = 'foo'")
   .Attr("i: int = 0")
   .Attr("f: float = 1.0")
   .Attr("b: bool = true")
   .Attr("ty: type = DT_INT32")
   .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
   .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
   .Attr("l_empty: list(int) = []")
   .Attr("l_int: list(int) = [2, 3, 5, 7]");

Notez en particulier que les valeurs de type type utilisenttf.DType .

Polymorphisme

Polymorphisme de type

Pour les opérations qui peuvent prendre différents types en entrée ou produire différents types de sortie, vous pouvez spécifier un attr dans un type d'entrée ou de sortie dans l'enregistrement d'opérations. En règle générale, vous enregistrez alors un OpKernel pour chaque type pris en charge.

Par exemple, si vous souhaitez que l'opération ZeroOut fonctionne sur des float en plus des int32 , votre enregistrement d'opération pourrait ressembler à:

REGISTER_OP("ZeroOut")
    .Attr("T: {float, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

Votre enregistrement d'opération spécifie maintenant que le type de l'entrée doit être float , ou int32 , et que sa sortie sera du même type, puisque les deux sont de type T

Appellation

Les entrées, sorties et attrs doivent généralement recevoir des noms snake_case. La seule exception concerne les attrs qui sont utilisés comme type d'entrée ou comme type de sortie. Ces attrs peuvent être déduits lorsque l'op est ajouté au graphique et n'apparaissent donc pas dans la fonction de l'op. Par exemple, cette dernière définition de ZeroOut générera une fonction Python qui ressemble à:

def zero_out(to_zero, name=None):
  """...
  Args:
    to_zero: A `Tensor`. Must be one of the following types:
        `float32`, `int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `to_zero`.
  """

Si to_zero est passé un tenseur int32 , alors T est automatiquement mis à int32 (enfin, en fait DT_INT32 ). Ces attrs déduits reçoivent des noms Capitalized ou CamelCase.

Comparez cela avec un op qui a un type attr qui détermine le type de sortie:

REGISTER_OP("StringToNumber")
    .Input("string_tensor: string")
    .Output("output: out_type")
    .Attr("out_type: {float, int32} = DT_FLOAT");
    .Doc(R"doc(
Converts each string in the input Tensor to the specified numeric type.
)doc");

Dans ce cas, l'utilisateur doit spécifier le type de sortie, comme dans le Python généré:

def string_to_number(string_tensor, out_type=None, name=None):
  """Converts each string in the input Tensor to the specified numeric type.

  Args:
    string_tensor: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.int32`.
      Defaults to `tf.float32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
Exemple de polymorphisme de type
#include "tensorflow/core/framework/op_kernel.h"

class ZeroOutInt32Op : public OpKernel {
  // as before
};

class ZeroOutFloatOp : public OpKernel {
 public:
  explicit ZeroOutFloatOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<float>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<float>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutInt32Op);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutFloatOp);

Pour préserver la compatibilité descendante , vous devez spécifier une valeur par défaut lors de l'ajout d'un attr à une opération existante:

REGISTER_OP("ZeroOut")
  .Attr("T: {float, int32} = DT_INT32")
  .Input("to_zero: T")
  .Output("zeroed: T")

Supposons que vous vouliez ajouter plus de types, disons double :

REGISTER_OP("ZeroOut")
    .Attr("T: {float, double, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

Au lieu d'écrire un autre OpKernel avec du code redondant comme ci-dessus, vous pourrez souvent utiliser un modèle C ++ à la place. Vous aurez toujours un enregistrement de noyau (appel REGISTER_KERNEL_BUILDER ) par surcharge.

template <typename T>
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<T>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<T>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<double>("T"),
    ZeroOutOp<double>);

Si vous avez plus de quelques surcharges, vous pouvez placer l'enregistrement dans une macro.

#include "tensorflow/core/framework/op_kernel.h"

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);

#undef REGISTER_KERNEL

En fonction de la liste des types pour tensorflow/core/framework/register_types.h vous enregistrez le noyau, vous pourrez peut-être utiliser une macro fournie par tensorflow/core/framework/register_types.h :

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"

REGISTER_OP("ZeroOut")
    .Attr("T: realnumbertype")
    .Input("to_zero: T")
    .Output("zeroed: T");

template <typename T>
class ZeroOutOp : public OpKernel { ... };

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL
Liste des entrées et sorties

En plus de pouvoir accepter ou produire différents types, les opérations peuvent consommer ou produire un nombre variable de tenseurs.

Dans l'exemple suivant, la attr T contient une liste de types, et est utilisé comme le type à la fois l'entrée in et la sortie out . L'entrée et la sortie sont des listes de tenseurs de ce type (et le nombre et les types de tenseurs dans la sortie sont les mêmes que l'entrée, puisque les deux sont de type T ).

REGISTER_OP("PolymorphicListExample")
    .Attr("T: list(type)")
    .Input("in: T")
    .Output("out: T");

Vous pouvez également placer des restrictions sur les types pouvant être spécifiés dans la liste. Dans ce cas suivant, l'entrée est une liste de tenseurs float et double . L'op accepte, par exemple, les types d'entrée (float, double, float) et dans ce cas, le type de sortie serait également (float, double, float) .

REGISTER_OP("ListTypeRestrictionExample")
    .Attr("T: list({float, double})")
    .Input("in: T")
    .Output("out: T");

Si vous voulez que tous les tenseurs d'une liste soient du même type, vous pouvez faire quelque chose comme:

REGISTER_OP("IntListInputExample")
    .Attr("N: int")
    .Input("in: N * int32")
    .Output("out: int32");

Cela accepte une liste de tenseurs int32 et utilise un int attr N pour spécifier la longueur de la liste.

Cela peut également être rendu polymorphe . Dans l'exemple suivant, l'entrée est une liste de tenseurs (de longueur "N" ) du même type (mais non spécifié) ( "T" ), et la sortie est un seul tenseur de type correspondant:

REGISTER_OP("SameListInputExample")
    .Attr("N: int")
    .Attr("T: type")
    .Input("in: N * T")
    .Output("out: T");

Par défaut, les listes de tenseurs ont une longueur minimale de 1. Vous pouvez modifier cette valeur par défaut en utilisant une contrainte ">=" sur l'attr correspondant . Dans cet exemple suivant, l'entrée est une liste d'au moins 2 tenseurs int32 :

REGISTER_OP("MinLengthIntListExample")
    .Attr("N: int >= 2")
    .Input("in: N * int32")
    .Output("out: int32");

La même syntaxe fonctionne avec les attrs "list(type)" :

REGISTER_OP("MinimumLengthPolymorphicListExample")
    .Attr("T: list(type) >= 3")
    .Input("in: T")
    .Output("out: T");

Entrées et sorties

Pour résumer ce qui précède, un enregistrement d'opération peut avoir plusieurs entrées et sorties:

REGISTER_OP("MultipleInsAndOuts")
    .Input("y: int32")
    .Input("z: float")
    .Output("a: string")
    .Output("b: int32");

Chaque spécification d'entrée ou de sortie est de la forme:

<name>: <io-type-expr>

<name> commence par une lettre et peut être composé de caractères alphanumériques et de traits de soulignement. <io-type-expr> est l'une des expressions de type suivantes:

  • <type> , où <type> est un type d'entrée pris en charge (par exemple float , int32 , string ). Ceci spécifie un seul tenseur du type donné.

    Voirtf.DType .

    REGISTER_OP("BuiltInTypesExample")
        .Input("integers: int32")
        .Input("complex_numbers: complex64");
    
  • <attr-type> , où <attr-type> est le nom d'un Attr avec type type ou list(type) (avec une restriction de type possible). Cette syntaxe permet des opérations polymorphes .

    REGISTER_OP("PolymorphicSingleInput")
        .Attr("T: type")
        .Input("in: T");
    
    REGISTER_OP("RestrictedPolymorphicSingleInput")
        .Attr("T: {int32, int64}")
        .Input("in: T");
    

    Le référencement d'un attr de type list(type) vous permet d'accepter une suite de tenseurs.

    REGISTER_OP("ArbitraryTensorSequenceExample")
        .Attr("T: list(type)")
        .Input("in: T")
        .Output("out: T");
    
    REGISTER_OP("RestrictedTensorSequenceExample")
        .Attr("T: list({int32, int64})")
        .Input("in: T")
        .Output("out: T");
    

    Notez que le nombre et les types de tenseurs dans la sortie out sont les mêmes que dans l'entrée in , car les deux sont de type T

  • Pour une suite de tenseurs de même type: <number> * <type> , où <number> est le nom d'un Attr de type int . Le <type> peut être soit untf.DType , soit le nom d'un attr avec un type de type . Comme exemple de la première, cette opération accepte une liste de tenseurs int32 :

    REGISTER_OP("Int32SequenceExample")
        .Attr("NumTensors: int")
        .Input("in: NumTensors * int32")
    

    Alors que cette opération accepte une liste de tenseurs de n'importe quel type, tant qu'ils sont tous identiques:

    REGISTER_OP("SameTypeSequenceExample")
        .Attr("NumTensors: int")
        .Attr("T: type")
        .Input("in: NumTensors * T")
    
  • Pour une référence à un tenseur: Ref(<type>) , où <type> est l'un des types précédents.

Tout attr utilisé dans le type d'une entrée sera déduit. Par convention, ces attrs déduits utilisent des noms majuscules (comme T ou N ). Sinon, les entrées, les sorties et les attrs ont des noms comme des paramètres de fonction (par exemple num_outputs ). Pour plus de détails, consultez la section précédente sur la dénomination .

Pour plus de détails, consultez tensorflow/core/framework/op_def_builder.h .

Rétrocompatibilité

Supposons que vous ayez écrit une belle opération personnalisée et que vous la partagiez avec d'autres, afin que vous ayez des clients satisfaits qui utilisent votre opération. Cependant, vous souhaitez apporter des modifications à l'opération d'une manière ou d'une autre.

En général, les modifications apportées aux spécifications existantes enregistrées doivent être rétrocompatibles: la modification de la spécification d'un op ne doit pas casser les tampons de protocole GraphDef sérialisés antérieurs construits à partir de spécifications plus anciennes. Les détails de la compatibilité GraphDef sont décrits ici .

Il existe plusieurs façons de préserver la rétrocompatibilité.

  1. Tout nouveau attrs ajouté à une opération doit avoir des valeurs par défaut définies, et avec cette valeur par défaut, l'opération doit avoir le comportement d'origine. Pour changer une opération de non polymorphe à polymorphe, vous devez donner une valeur par défaut au nouveau type attr pour conserver la signature d'origine par défaut. Par exemple, si votre opération était:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: float")
        .Output("out: float");
    

    vous pouvez le rendre polymorphe de manière rétrocompatible en utilisant:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: T")
        .Output("out: T")
        .Attr("T: numerictype = DT_FLOAT");
    
  2. Vous pouvez sans risque rendre une contrainte sur un attr moins restrictive. Par exemple, vous pouvez passer de {int32, int64} à {int32, int64, float} ou type . Ou vous pouvez passer de {"apple", "orange"} à {"apple", "banana", "orange"} ou string .

  3. Vous pouvez changer les entrées / sorties individuelles en entrées / sorties de liste, à condition que la valeur par défaut du type de liste corresponde à l'ancienne signature.

  4. Vous pouvez ajouter une nouvelle entrée / sortie de liste, si elle est vide par défaut.

  5. Nommez toutes les nouvelles opérations que vous créez, en préfixant les noms d'opérations avec quelque chose d'unique à votre projet. Cela évite que votre opération n'entre en collision avec toute opération qui pourrait être incluse dans les futures versions de TensorFlow.

  6. Planifier à l'avance! Essayez d'anticiper les utilisations futures de l'op. Certaines modifications de signature ne peuvent pas être effectuées de manière compatible (par exemple, en transformant une liste du même type en une liste de types différents).

La liste complète des modifications sûres et non sécurisées se trouve dans tensorflow/core/framework/op_compatibility_test.cc . Si vous ne pouvez pas apporter votre modification à une opération rétrocompatible, créez une nouvelle opération avec un nouveau nom avec la nouvelle sémantique.

Notez également que si ces modifications peuvent maintenir la compatibilité GraphDef , le code Python généré peut changer d'une manière qui n'est pas compatible avec les anciens appelants. L'API Python peut être maintenue compatible en modifiant soigneusement un wrapper Python écrit à la main, en conservant l'ancienne signature, sauf en ajoutant éventuellement de nouveaux arguments optionnels à la fin. Des modifications généralement incompatibles ne peuvent être apportées que lorsque TensorFlow change de versions majeures et doivent se conformer à la sémantique de la version de GraphDef .

Prise en charge du GPU

Vous pouvez implémenter différents OpKernels et en enregistrer un pour le CPU et un autre pour le GPU, tout comme vous pouvez enregistrer des noyaux pour différents types . Il existe plusieurs exemples de noyaux avec prise en charge GPU dans tensorflow/core/kernels/ . Notez que certains noyaux ont une version CPU dans un fichier .cc , une version GPU dans un fichier se terminant par _gpu.cu.cc , et du code partagé en commun dans un fichier .h .

Par exemple, le tf.pad a tout sauf le noyau GPU dans tensorflow/core/kernels/pad_op.cc . Le noyau GPU se trouve dans tensorflow/core/kernels/pad_op_gpu.cu.cc , et le code partagé est une classe tensorflow/core/kernels/pad_op.h un modèle définie dans tensorflow/core/kernels/pad_op.h . Nous organisons le code de cette façon pour deux raisons: cela vous permet de partager du code commun entre les implémentations CPU et GPU, et il place l'implémentation GPU dans un fichier séparé afin qu'elle ne puisse être compilée que par le compilateur GPU.

Une chose à noter, même lorsque la version du noyau GPU du pad est utilisée, il a toujours besoin de son entrée "paddings" dans la mémoire du processeur. Pour marquer que les entrées ou les sorties sont conservées sur le CPU, ajoutez un appel HostMemory() à l'enregistrement du noyau, par exemple:

#define REGISTER_GPU_KERNEL(T)                         \
  REGISTER_KERNEL_BUILDER(Name("Pad")                  \
                              .Device(DEVICE_GPU)      \
                              .TypeConstraint<T>("T")  \
                              .HostMemory("paddings"), \
                          PadOp<GPUDevice, T>)

Compilation du noyau pour le périphérique GPU

Regardez cuda_op_kernel.cu.cc pour un exemple qui utilise un noyau CUDA pour implémenter un op. La tf_custom_op_library accepte un argument gpu_srcs dans lequel la liste des fichiers source contenant les noyaux CUDA (fichiers *.cu.cc ) peut être spécifiée. Pour une utilisation avec une installation binaire de TensorFlow, les noyaux CUDA doivent être compilés avec le compilateur nvcc de NVIDIA. Voici la séquence de commandes que vous pouvez utiliser pour compiler cuda_op_kernel.cu.cc et cuda_op_kernel.cc dans une seule bibliothèque chargeable dynamiquement:

nvcc -std=c++11 -c -o cuda_op_kernel.cu.o cuda_op_kernel.cu.cc \
  ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC

g++ -std=c++11 -shared -o cuda_op_kernel.so cuda_op_kernel.cc \
  cuda_op_kernel.cu.o ${TF_CFLAGS[@]} -fPIC -lcudart ${TF_LFLAGS[@]}

cuda_op_kernel.so produit ci-dessus peut être chargé comme d'habitude en Python, en utilisant la fonction tf.load_op_library .

Notez que si vos bibliothèques CUDA ne sont pas installées dans /usr/local/lib64 , vous devrez spécifier le chemin explicitement dans la deuxième commande (g ++) ci-dessus. Par exemple, ajoutez -L /usr/local/cuda-8.0/lib64/ si votre CUDA est installé dans /usr/local/cuda-8.0 .

Implémenter le dégradé en Python

Étant donné un graphique d'opérations, TensorFlow utilise la différenciation automatique (rétropropagation) pour ajouter de nouvelles opérations représentant des dégradés par rapport aux opérations existantes. Pour que la différenciation automatique fonctionne pour les nouvelles opérations, vous devez enregistrer une fonction de gradient qui calcule les gradients par rapport aux entrées des opérations étant donné les gradients par rapport aux sorties des opérations.

Mathématiquement, si un op calcule \(y = f(x)\), l'op de gradient enregistré convertit les gradients \(\partial L/ \partial y\) de perte \(L\) par rapport à \(y\) en gradients \(\partial L/ \partial x\) par rapport à \(x\) via la règle de chaîne:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial f}{\partial x}.$$

Dans le cas de ZeroOut , une seule entrée dans l'entrée affecte la sortie, donc le gradient par rapport à l'entrée est un tenseur clairsemé "one hot". Ceci est exprimé comme suit:

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops

@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
  """The gradients for `zero_out`.

  Args:
    op: The `zero_out` `Operation` that we are differentiating, which we can use
      to find the inputs and outputs of the original op.
    grad: Gradient with respect to the output of the `zero_out` op.

  Returns:
    Gradients with respect to the input of `zero_out`.
  """
  to_zero = op.inputs[0]
  shape = array_ops.shape(to_zero)
  index = array_ops.zeros_like(shape)
  first_grad = array_ops.reshape(grad, [-1])[0]
  to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0)
  return [to_zero_grad]  # List of one Tensor, since we have one input

Détails sur l'enregistrement des fonctions de gradient avec tf.RegisterGradient :

  • Pour un op avec une sortie, la fonction gradient prendra un tf.Operation , op et un tf.Tensor grad et construira de nouvelles ops à partir des tenseurs op.inputs[i] , op.outputs[i] et grad . Des informations sur les attrs peuvent être trouvées via tf.Operation.get_attr .

  • Si l'op a plusieurs sorties, la fonction de dégradé prendra op et grads , où grads est une liste de dégradés par rapport à chaque sortie. Le résultat de la fonction de dégradé doit être une liste d'objets Tensor représentant les dégradés par rapport à chaque entrée.

  • S'il n'y a pas de gradient bien défini pour certaines entrées, comme pour les entrées entières utilisées comme indices, le gradient retourné correspondant doit être None . Par exemple, pour un op prenant un tenseur à virgule flottante x et un indice entier i , la fonction de gradient return [x_grad, None] .

  • S'il n'y a pas du tout de dégradé significatif pour l'opération, vous n'aurez souvent pas à enregistrer de dégradé, et tant que le dégradé de l'opération n'est jamais nécessaire, tout ira bien. Dans certains cas, un op n'a pas de gradient bien défini mais peut être impliqué dans le calcul du gradient. Ici, vous pouvez utiliser ops.NotDifferentiable pour propager automatiquement les zéros vers l'arrière.

Notez qu'au moment où la fonction de gradient est appelée, seul le graphique de flux de données des opérations est disponible, et non les données du tenseur elles-mêmes. Ainsi, tous les calculs doivent être effectués à l'aide d'autres opérations tensorflow, à exécuter au moment de l'exécution du graphe.

Fonctions de forme en C ++

L'API TensorFlow possède une fonctionnalité appelée «inférence de forme» qui fournit des informations sur les formes des tenseurs sans avoir à exécuter le graphe. L'inférence de forme est prise en charge par les «fonctions de forme» qui sont enregistrées pour chaque type d'opération dans la déclaration C ++ REGISTER_OP et remplissent deux rôles: affirmer que les formes des entrées sont compatibles lors de la construction du graphe et spécifier les formes des sorties.

Les fonctions de forme sont définies comme des opérations sur la classe shape_inference::InferenceContext . Par exemple, dans la fonction de forme pour ZeroOut:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

c->set_output(0, c->input(0)); déclare que la forme de la première sortie doit être définie sur la forme de la première entrée. Si la sortie est sélectionnée par son index comme dans l'exemple ci-dessus, le deuxième paramètre de set_output doit être un objet ShapeHandle . Vous pouvez créer un objet ShapeHandle vide par son constructeur par défaut. L'objet ShapeHandle pour une entrée avec un index idx peut être obtenu par c->input(idx) .

Il existe un certain nombre de fonctions de forme communes qui s'appliquent à de nombreuses opérations, telles que shape_inference::UnchangedShape qui peut être trouvée dans common_shape_fns.h et utilisée comme suit:

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn(::tensorflow::shape_inference::UnchangedShape);

Une fonction de forme peut également contraindre la forme d'une entrée. Pour la version de ZeroOut avec une contrainte de forme vectorielle , la fonction de forme serait la suivante:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
      c->set_output(0, input);
      return Status::OK();
    });

L'appel WithRank valide que la forme d'entrée c->input(0) a une forme avec exactement une dimension (ou si la forme d'entrée est inconnue, la forme de sortie sera un vecteur avec une dimension inconnue).

Si votre opération est polymorphe avec plusieurs entrées , vous pouvez utiliser des membres d' InferenceContext pour déterminer le nombre de formes à vérifier, et Merge pour valider que les formes sont toutes compatibles ( InferenceContext::GetAttr , InferenceContext::GetAttr aux attributs qui indiquent les longueurs, avec InferenceContext::GetAttr , qui donne accès aux attributs de l'op).

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      ::tensorflow::shape_inference::ShapeHandle output;
      for (size_t i = 0; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &input));
        TF_RETURN_IF_ERROR(c->Merge(output, input, &output));
      }
      c->set_output(0, output);
      return Status::OK();
    });

Étant donné que l'inférence de forme est une caractéristique facultative et que les formes des tenseurs peuvent varier de manière dynamique, les fonctions de forme doivent être robustes aux informations de forme incomplètes pour l'une des entrées. La méthode Merge dans InferenceContext permet à l'appelant d'affirmer que deux formes sont identiques, même si l'une d'entre elles ou les deux ne disposent pas d'informations complètes. Les fonctions de forme sont définies pour toutes les opérations principales de TensorFlow et fournissent de nombreux exemples d'utilisation différents.

La classe InferenceContext a un certain nombre de fonctions qui peuvent être utilisées pour définir des manipulations de fonction de forme. Par exemple, vous pouvez valider qu'une dimension particulière a une valeur très spécifique en utilisant InferenceContext::Dim et InferenceContext::WithValue ; vous pouvez spécifier qu'une dimension de sortie est la somme / le produit de deux dimensions d'entrée en utilisant InferenceContext::Add et InferenceContext::Multiply . Consultez la classe InferenceContext pour toutes les différentes manipulations de forme que vous pouvez spécifier. L'exemple suivant définit la forme de la première sortie sur (n, 3), où la première entrée a la forme (n, ...)

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    c->set_output(0, c->Matrix(c->Dim(c->input(0), 0), 3));
    return Status::OK();
});

Si vous avez une fonction de forme complexe, vous devez envisager d'ajouter un test pour vérifier que diverses combinaisons de formes d'entrée produisent les combinaisons de formes de sortie attendues. Vous pouvez voir des exemples d'écriture de ces tests dans certains de nos tests d'opérations de base . (La syntaxe de INFER_OK et INFER_ERROR est un peu cryptique, mais essayez d'être compact dans la représentation des spécifications de forme d'entrée et de sortie dans les tests. Pour l'instant, consultez les commentaires environnants dans ces tests pour avoir une idée de la spécification de la chaîne de forme).

Créez un package pip pour votre opération personnalisée

Pour créer un package pip pour votre opération, consultez l'exemple tensorflow / custom-op . Ce guide montre comment créer des opérations personnalisées à partir du package pip TensorFlow au lieu de créer TensorFlow à partir de la source.