Crear una operación

Si desea crear una operación que no esté cubierta por la biblioteca de TensorFlow existente, le recomendamos que primero intente escribir la operación en Python como una composición de las operaciones o funciones de Python existentes. Si eso no es posible, puede crear una operación de C ++ personalizada. Hay varias razones por las que es posible que desee crear una operación de C ++ personalizada:

  • No es fácil ni posible expresar su operación como una composición de operaciones existentes.
  • No es eficiente expresar su operación como una composición de primitivas existentes.
  • Desea fusionar manualmente una composición de primitivas que un futuro compilador encontraría difícil de fusionar.

Por ejemplo, imagine que desea implementar algo como "agrupación de medianas", similar al operador "MaxPool", pero calculando medianas sobre ventanas deslizantes en lugar de valores máximos. Hacer esto usando una composición de operaciones puede ser posible (por ejemplo, usando ExtractImagePatches y TopK), pero puede no ser tan eficiente en rendimiento o memoria como una operación nativa donde puede hacer algo más inteligente en una sola operación fusionada. Como siempre, por lo general, primero vale la pena intentar expresar lo que desea mediante la composición del operador, y solo optar por agregar una nueva operación si resulta difícil o ineficaz.

Para incorporar su operación personalizada, deberá:

  1. Registre la nueva operación en un archivo C ++. El registro de la operación define una interfaz (especificación) para la funcionalidad de la operación, que es independiente de la implementación de la operación. Por ejemplo, el registro de la operación define el nombre de la operación y las entradas y salidas de la operación. También define la función de forma que se utiliza para la inferencia de forma de tensor.
  2. Implemente la operación en C ++. La implementación de una operación se conoce como kernel, y es la implementación concreta de la especificación que registró en el Paso 1. Puede haber múltiples kernels para diferentes tipos de entrada / salida o arquitecturas (por ejemplo, CPU, GPU).
  3. Crea una envoltura de Python (opcional). Este contenedor es la API pública que se usa para crear la operación en Python. Se genera un contenedor predeterminado a partir del registro de operaciones, que se puede usar directamente o agregar.
  4. Escribe una función para calcular gradientes para la operación (opcional).
  5. Prueba la op. Por lo general, hacemos esto en Python por conveniencia, pero también puede probar la operación en C ++. Si define gradientes, puede verificarlos con Python tf.test.compute_gradient_error . Vea relu_op_test.py como un ejemplo que prueba las funciones de avance de operadores similares a Relu y sus gradientes.

Prerrequisitos

Definir la interfaz de operaciones

Usted define la interfaz de una operación registrándola con el sistema TensorFlow. En el registro, especifica el nombre de su operación, sus entradas (tipos y nombres) y salidas (tipos y nombres), así como las cadenas de documentación y cualquier atributo que la operación pueda requerir.

Para ver cómo funciona esto, suponga que le gustaría crear una int32 que tome un tensor de int32 sy int32 una copia del tensor, con todos los elementos menos el primero establecidos en cero. Para hacer esto, cree un archivo llamado zero_out.cc . Luego agregue una llamada a la macro REGISTER_OP que define la interfaz para su operación:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

Esta ZeroOut toma un tensor to_zero de to_zero de 32 bits como entrada y genera un tensor con zeroed de enteros de 32 bits. La operación también usa una función de forma para garantizar que el tensor de salida tenga la misma forma que el tensor de entrada. Por ejemplo, si la entrada es un tensor de forma [10, 20], entonces esta función de forma especifica que la forma de salida también es [10, 20].

Implementar el kernel para la operación

Después de definir la interfaz, proporcione una o más implementaciones del op. Para crear uno de estos núcleos, cree una clase que amplíe OpKernel y anule el método Compute . El método Compute proporciona un argumento de context de tipo OpKernelContext* , desde el cual puede acceder a cosas útiles como los tensores de entrada y salida.

Agregue su kernel al archivo que creó anteriormente. El kernel podría verse así:

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};

Después de implementar su kernel, lo registra con el sistema TensorFlow. En el registro, especifica diferentes restricciones bajo las cuales se ejecutará este kernel. Por ejemplo, puede tener un kernel hecho para CPU y otro separado para GPU.

Para hacer esto para la ZeroOut , agregue lo siguiente a zero_out.cc :

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

Núcleos de CPU de subprocesos múltiples

Para escribir un núcleo de CPU work_sharder.h se puede utilizar la función Shard en work_sharder.h . Esta función fragmenta una función de cálculo entre los subprocesos configurados para usarse en subprocesos intraoperativos (consulte intra_op_parallelism_threads en config.proto ).

Núcleos de GPU

Un kernel de GPU se implementa en dos partes: el OpKernel y el kernel CUDA y su código de lanzamiento.

A veces, la implementación de OpKernel es común entre un núcleo de CPU y GPU, como inspeccionar entradas y asignar salidas. En ese caso, una implementación sugerida es:

  1. Defina la plantilla OpKernel en el dispositivo y el tipo primitivo del tensor.
  2. Para hacer el cálculo real de la salida, la función Compute llama a una estructura de functor con plantilla.
  3. La especialización de ese functor para el CPUDevice se define en el mismo archivo, pero la especialización para GPUDevice se define en un archivo .cu.cc, ya que se compilará con el compilador CUDA.

Aquí hay una implementación de ejemplo.

// kernel_example.h
#ifndef KERNEL_EXAMPLE_H_
#define KERNEL_EXAMPLE_H_

#include <unsupported/Eigen/CXX11/Tensor>

template <typename Device, typename T>
struct ExampleFunctor {
  void operator()(const Device& d, int size, const T* in, T* out);
};

#if GOOGLE_CUDA
// Partially specialize functor for GpuDevice.
template <typename T>
struct ExampleFunctor<Eigen::GpuDevice, T> {
  void operator()(const Eigen::GpuDevice& d, int size, const T* in, T* out);
};
#endif

#endif KERNEL_EXAMPLE_H_
// kernel_example.cc
#include "kernel_example.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

REGISTER_OP("Example")
    .Attr("T: numbertype")
    .Input("input: T")
    .Output("input_times_two: T")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

// CPU specialization of actual computation.
template <typename T>
struct ExampleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, int size, const T* in, T* out) {
    for (int i = 0; i < size; ++i) {
      out[i] = 2 * in[i];
    }
  }
};

// OpKernel definition.
// template parameter <T> is the datatype of the tensors.
template <typename Device, typename T>
class ExampleOp : public OpKernel {
 public:
  explicit ExampleOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));

    // Do the computation.
    OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max,
                errors::InvalidArgument("Too many elements in tensor"));
    ExampleFunctor<Device, T>()(
        context->eigen_device<Device>(),
        static_cast<int>(input_tensor.NumElements()),
        input_tensor.flat<T>().data(),
        output_tensor->flat<T>().data());
  }
};

// Register the CPU kernels.
#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      ExampleOp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(int32);

// Register the GPU kernels.
#ifdef GOOGLE_CUDA
#define REGISTER_GPU(T)                                          \
  /* Declare explicit instantiations in kernel_example.cu.cc. */ \
  extern template class ExampleFunctor<GPUDevice, T>;            \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      ExampleOp<GPUDevice, T>);
REGISTER_GPU(float);
REGISTER_GPU(int32);
#endif  // GOOGLE_CUDA
// kernel_example.cu.cc
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "kernel_example.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

using namespace tensorflow;

using GPUDevice = Eigen::GpuDevice;

// Define the CUDA kernel.
template <typename T>
__global__ void ExampleCudaKernel(const int size, const T* in, T* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += blockDim.x * gridDim.x) {
    out[i] = 2 * __ldg(in + i);
  }
}

// Define the GPU implementation that launches the CUDA kernel.
template <typename T>
void ExampleFunctor<GPUDevice, T>::operator()(
    const GPUDevice& d, int size, const T* in, T* out) {
  // Launch the cuda kernel.
  //
  // See core/util/gpu_kernel_helper.h for example of computing
  // block count and thread_per_block count.
  int block_count = 1024;
  int thread_per_block = 20;
  ExampleCudaKernel<T>
      <<<block_count, thread_per_block, 0, d.stream()>>>(size, in, out);
}

// Explicitly instantiate functors for the types of OpKernels registered.
template struct ExampleFunctor<GPUDevice, float>;
template struct ExampleFunctor<GPUDevice, int32>;

#endif  // GOOGLE_CUDA

Construye la biblioteca de operaciones

Compile la operación usando el compilador de su sistema (instalación binaria de TensorFlow)

Debería poder compilar zero_out.cc con un compilador de C++ como g++ o clang disponible en su sistema. El paquete binario PIP instala los archivos de encabezado y la biblioteca que necesita para compilar su operación en ubicaciones que son específicas del sistema. Sin embargo, la biblioteca de Python de TensorFlow proporciona la función get_include para obtener el directorio de encabezado, y el directorio get_lib tiene un objeto compartido para vincular. Aquí están los resultados de estas funciones en una máquina Ubuntu.

$ python
>>> import tensorflow as tf
>>> tf.sysconfig.get_include()
'/usr/local/lib/python3.6/site-packages/tensorflow/include'
>>> tf.sysconfig.get_lib()
'/usr/local/lib/python3.6/site-packages/tensorflow'

Suponiendo que tiene g++ instalado, aquí está la secuencia de comandos que puede usar para compilar su operación en una biblioteca dinámica.

TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++11 -shared zero_out.cc -o zero_out.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2

En MacOS, se requiere que el indicador adicional "dynamic_lookup -undefined" en la construcción de la .so archivo.

Nota sobre la versión de gcc >=5 : gcc usa la nueva ABI de C ++ desde la versión 5 . Los paquetes de pip binarios disponibles en el sitio web de TensorFlow se gcc4 con gcc4 que usa la ABI anterior. Si compila su biblioteca de operaciones con gcc>=5 , agregue -D_GLIBCXX_USE_CXX11_ABI=0 a la línea de comando para hacer que la biblioteca sea compatible con el abi anterior.

Compile la operación usando bazel (instalación de fuente de TensorFlow)

Si tiene fuentes de TensorFlow instaladas, puede hacer uso del sistema de compilación de TensorFlow para compilar su operación. Coloque un archivo BUILD con la siguiente regla de compilación de Bazel en el tensorflow/core/user_ops .

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    name = "zero_out.so",
    srcs = ["zero_out.cc"],
)

Ejecute el siguiente comando para construir zero_out.so .

$ bazel build --config opt //tensorflow/core/user_ops:zero_out.so

Para compilar el Example operación, con el kernel de CUDA, es necesario utilizar los gpu_srcs parámetro de tf_custom_op_library . Coloque un archivo BUILD con la siguiente regla de compilación de Bazel en una nueva carpeta dentro del tensorflow/core/user_ops (por ejemplo, "example_gpu").

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    # kernel_example.cc  kernel_example.cu.cc  kernel_example.h
    name = "kernel_example.so",
    srcs = ["kernel_example.h", "kernel_example.cc"],
    gpu_srcs = ["kernel_example.cu.cc", "kernel_example.h"],
)

Ejecute el siguiente comando para compilar kernel_example.so .

$ bazel build --config opt //tensorflow/core/user_ops/example_gpu:kernel_example.so

Usa la operación en Python

La API de Python de TensorFlow proporciona la función tf.load_op_library para cargar la biblioteca dinámica y registrar la operación con el marco de trabajo de TensorFlow. load_op_library devuelve un módulo de Python que contiene los envoltorios de Python para la operación y el kernel. Por lo tanto, una vez que haya creado la operación, puede hacer lo siguiente para ejecutarla desde Python:

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
print(zero_out_module.zero_out([[1, 2], [3, 4]]).numpy())

# Prints
array([[1, 0], [0, 0]], dtype=int32)

Tenga en cuenta que a la función generada se le asignará un nombre snake_case (para cumplir con PEP8 ). Entonces, si su ZeroOut se llama ZeroOut en los archivos C ++, la función de Python se llamará zero_out .

Para que la operación esté disponible como una función normal que se pueda import desde un módulo de Python, puede ser útil tener la llamada load_op_library en un archivo fuente de Python de la siguiente manera:

import tensorflow as tf

zero_out_module = tf.load_op_library('./zero_out.so')
zero_out = zero_out_module.zero_out

Verifica que la operación funcione

Una buena forma de verificar que ha implementado con éxito su operación es escribir una prueba para ella. Cree el archivo zero_out_op_test.py con el contenido:

import tensorflow as tf

class ZeroOutTest(tf.test.TestCase):
  def testZeroOut(self):
    zero_out_module = tf.load_op_library('./zero_out.so')
    with self.test_session():
      result = zero_out_module.zero_out([5, 4, 3, 2, 1])
      self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])

if __name__ == "__main__":
  tf.test.main()

Luego ejecute su prueba (suponiendo que tenga instalado tensorflow):

$ python zero_out_op_test.py

Construya funciones avanzadas en su operación

Ahora que sabe cómo construir una operación e implementación básica (y algo restringida), veremos algunas de las cosas más complicadas que normalmente necesitará construir en su operación. Esto incluye:

Verificaciones y validaciones condicionales

El ejemplo anterior asumió que la op se aplica a un tensor de cualquier forma. ¿Y si solo se aplicara a los vectores? Eso significa agregar una verificación a la implementación de OpKernel anterior.

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
                errors::InvalidArgument("ZeroOut expects a 1-D vector."));
    // ...
  }

Esto afirma que la entrada es un vector, y regresa habiendo establecido el estado InvalidArgument si no lo es. La macro OP_REQUIRES toma tres argumentos:

Alternativamente, si desea probar si un objeto de Status devuelto por alguna función es un error y, de ser así, devolverlo, use OP_REQUIRES_OK . Ambas macros regresan de la función en caso de error.

Registro de operaciones

Atributos

Las operaciones pueden tener atributos, cuyos valores se establecen cuando la operación se agrega a un gráfico. Estos se utilizan para configurar la operación, y se puede acceder a sus valores tanto dentro de la implementación del kernel como en los tipos de entradas y salidas en el registro de la operación. Prefiera usar una entrada en lugar de un atributo cuando sea posible, ya que las entradas son más flexibles. Esto se debe a que los atributos son constantes y deben definirse en el momento de la construcción del gráfico. Por el contrario, las entradas son tensores cuyos valores pueden ser dinámicos; es decir, las entradas pueden cambiar en cada paso, configurarse usando un feed, etc. Los atributos se usan para cosas que no se pueden hacer con entradas: cualquier configuración que afecte la firma (número o tipo de entradas o salidas) o que pueda ' t cambiar de un paso a otro.

Usted define un atributo cuando registra la operación, especificando su nombre y tipo usando el método Attr , que espera una especificación del formulario:

<name>: <attr-type-expr>

donde <name> comienza con una letra y puede estar compuesto por caracteres alfanuméricos y guiones bajos, y <attr-type-expr> es una expresión de tipo de la forma que se describe a continuación .

Por ejemplo, si desea que la ZeroOut conserve un índice especificado por el usuario, en lugar de solo el elemento 0, puede registrar la operación de la siguiente manera:

REGISTER_OP("ZeroOut")
    .Attr("preserve_index: int")
    .Input("to_zero: int32")
    .Output("zeroed: int32");

(Tenga en cuenta que el conjunto de tipos de atributos es diferente deltf.DType utilizado para entradas y salidas).

Su kernel puede acceder a este atributo en su constructor a través del parámetro de context :

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
    // Get the index of the value to preserve
    OP_REQUIRES_OK(context,
                   context->GetAttr("preserve_index", &preserve_index_));
    // Check that preserve_index is positive
    OP_REQUIRES(context, preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ",
                                        preserve_index_));
  }
  void Compute(OpKernelContext* context) override {
    // ...
  }
 private:
  int preserve_index_;
};

que luego se puede usar en el método Compute :

  void Compute(OpKernelContext* context) override {
    // ...

    // We're using saved attr to validate potentially dynamic input
    // So we check that preserve_index is in range
    OP_REQUIRES(context, preserve_index_ < input.dimension(0),
                errors::InvalidArgument("preserve_index out of range"));

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the requested input value
    output_flat(preserve_index_) = input(preserve_index_);
  }

Tipos de atributos

Los siguientes tipos se admiten en un atributo:

  • string : cualquier secuencia de bytes (no es necesario que sea UTF8).
  • int : un entero con signo.
  • float : un número de punto flotante.
  • bool : Verdadero o falso.
  • type : uno de los valores (no ref) de DataType .
  • shape : un TensorShapeProto .
  • list(<type>) : una lista de <type> , donde <type> es uno de los tipos anteriores. Tenga en cuenta que la list(list(<type>)) no es válida.

Consulte también: op_def_builder.cc:FinalizeAttr para obtener una lista definitiva.

Restricciones y valores predeterminados

Los atributos pueden tener valores predeterminados y algunos tipos de atributos pueden tener restricciones. Para definir un atributo con restricciones, puede utilizar los siguientes <attr-type-expr> s:

{'<string1>', '<string2>'} : el valor debe ser una cadena que tenga el valor <string1> o <string2> . El nombre del tipo, string , está implícito cuando usa esta sintaxis. Esto emula una enumeración:

REGISTER_OP("EnumExample")
    .Attr("e: {'apple', 'orange'}");

{<type1>, <type2>} : el valor es de tipo type y debe ser uno de <type1> o <type2> , donde <type1> y <type2> son compatibles con tf.DType . No especificas que el tipo de atributo es type . Esto está implícito cuando tiene una lista de tipos en {...} . Por ejemplo, en este caso el attr t es un tipo que debe ser un int32 , un float o un bool :

REGISTER_OP("RestrictedTypeExample")
    .Attr("t: {int32, float, bool}");

Hay atajos para restricciones de tipo comunes:

  • numbertype : tipo de type restringido a los type numéricos (sin cadena ni bool).
  • realnumbertype : Como numbertype sin tipos complejos.
  • quantizedtype : como el tipo numbertype pero solo los tipos numéricos cuantificados.

Las listas específicas de tipos permitidos por estos están definidas por las funciones (como NumberTypes() ) en tensorflow/core/framework/types.h . En este ejemplo el attr t debe ser uno de los tipos numéricos:

REGISTER_OP("NumberType")
    .Attr("t: numbertype");

Para esta operación:

tf.number_type(t=tf.int32)  # Valid
tf.number_type(t=tf.bool)   # Invalid

Las listas se pueden combinar con otras listas y tipos individuales. La siguiente operación permite que attr t sea ​​cualquiera de los tipos numéricos o el tipo bool:

REGISTER_OP("NumberOrBooleanType")
    .Attr("t: {numbertype, bool}");

Para esta operación:

tf.number_or_boolean_type(t=tf.int32)  # Valid
tf.number_or_boolean_type(t=tf.bool)   # Valid
tf.number_or_boolean_type(t=tf.string) # Invalid

int >= <n> : El valor debe ser un int cuyo valor sea mayor o igual que <n> , donde <n> es un número natural. Por ejemplo, el siguiente registro de operaciones especifica que el atributo a debe tener un valor de al menos 2 :

REGISTER_OP("MinIntExample")
    .Attr("a: int >= 2");

list(<type>) >= <n> : Una lista de tipo <type> cuya longitud es mayor o igual que <n> . Por ejemplo, el siguiente registro de operaciones especifica que el atributo a es una lista de tipos (ya sea int32 o float ), y que debe haber al menos 3 de ellos:

REGISTER_OP("TypeListExample")
    .Attr("a: list({int32, float}) >= 3");

Para establecer un valor predeterminado para un atributo (haciéndolo opcional en el código generado), agregue = <default> al final, como en:

REGISTER_OP("AttrDefaultExample")
    .Attr("i: int = 0");

Además, se pueden especificar tanto una restricción como un valor predeterminado:

REGISTER_OP("AttrConstraintAndDefaultExample")
    .Attr("i: int >= 1 = 1");

La sintaxis admitida del valor predeterminado es la que se usaría en la representación proto de la definición de GraphDef resultante.

A continuación, se muestran ejemplos de cómo especificar un valor predeterminado para todos los tipos:

REGISTER_OP("AttrDefaultExampleForAllTypes")
   .Attr("s: string = 'foo'")
   .Attr("i: int = 0")
   .Attr("f: float = 1.0")
   .Attr("b: bool = true")
   .Attr("ty: type = DT_INT32")
   .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
   .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
   .Attr("l_empty: list(int) = []")
   .Attr("l_int: list(int) = [2, 3, 5, 7]");

Tenga en cuenta, en particular, que los valores del tipo type usantf.DType .

Polimorfismo

Tipo de polimorfismo

Para operaciones que pueden tomar diferentes tipos como entrada o producir diferentes tipos de salida, puede especificar un atributo en un tipo de entrada o salida en el registro de operación. Por lo general, registraría un OpKernel para cada tipo admitido.

Por ejemplo, si desea que la ZeroOut funcione en float s además de int32 s, su registro de int32 podría verse así:

REGISTER_OP("ZeroOut")
    .Attr("T: {float, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

Su registro de operaciones ahora especifica que el tipo de entrada debe ser float , o int32 , y que su salida será del mismo tipo, ya que ambos tienen el tipo T

Nombrar

Las entradas, salidas y atributos generalmente deben recibir nombres de caso de serpiente. La única excepción son los atributos que se utilizan como tipo de entrada o como tipo de salida. Esos atributos se pueden inferir cuando la operación se agrega al gráfico y, por lo tanto, no aparecen en la función de la operación. Por ejemplo, esta última definición de ZeroOut generará una función de Python que se ve así:

def zero_out(to_zero, name=None):
  """...
  Args:
    to_zero: A `Tensor`. Must be one of the following types:
        `float32`, `int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `to_zero`.
  """

Si a to_zero se le pasa un tensor int32 , entonces T se establece automáticamente en int32 (bueno, en realidad DT_INT32 ). A esos atributos inferidos se les asignan nombres en mayúsculas o CamelCase.

Compare esto con una operación que tiene un tipo attr que determina el tipo de salida:

REGISTER_OP("StringToNumber")
    .Input("string_tensor: string")
    .Output("output: out_type")
    .Attr("out_type: {float, int32} = DT_FLOAT");
    .Doc(R"doc(
Converts each string in the input Tensor to the specified numeric type.
)doc");

En este caso, el usuario debe especificar el tipo de salida, como en el Python generado:

def string_to_number(string_tensor, out_type=None, name=None):
  """Converts each string in the input Tensor to the specified numeric type.

  Args:
    string_tensor: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.int32`.
      Defaults to `tf.float32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
Ejemplo de polimorfismo de tipo
#include "tensorflow/core/framework/op_kernel.h"

class ZeroOutInt32Op : public OpKernel {
  // as before
};

class ZeroOutFloatOp : public OpKernel {
 public:
  explicit ZeroOutFloatOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<float>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<float>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutInt32Op);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutFloatOp);

Para preservar la compatibilidad con versiones anteriores , debe especificar un valor predeterminado al agregar un atributo a una operación existente:

REGISTER_OP("ZeroOut")
  .Attr("T: {float, int32} = DT_INT32")
  .Input("to_zero: T")
  .Output("zeroed: T")

Digamos que desea agregar más tipos, digamos double :

REGISTER_OP("ZeroOut")
    .Attr("T: {float, double, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

En lugar de escribir otro OpKernel con código redundante como el anterior, a menudo podrá utilizar una plantilla de C ++. Aún tendrá un registro de kernel (llamada REGISTER_KERNEL_BUILDER ) por sobrecarga.

template <typename T>
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<T>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<T>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<double>("T"),
    ZeroOutOp<double>);

Si tiene más de un par de sobrecargas, puede poner el registro en una macro.

#include "tensorflow/core/framework/op_kernel.h"

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);

#undef REGISTER_KERNEL

Dependiendo de la lista de tipos para los que está registrando el kernel, es posible que pueda usar una macro proporcionada por tensorflow/core/framework/register_types.h :

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"

REGISTER_OP("ZeroOut")
    .Attr("T: realnumbertype")
    .Input("to_zero: T")
    .Output("zeroed: T");

template <typename T>
class ZeroOutOp : public OpKernel { ... };

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL
Lista de entradas y salidas

Además de poder aceptar o producir diferentes tipos, las operaciones pueden consumir o producir un número variable de tensores.

En el siguiente ejemplo, el attr T mantiene una lista de tipos, y se utiliza como el tipo de tanto la entrada in y la salida out . La entrada y la salida son listas de tensores de ese tipo (y el número y los tipos de tensores en la salida son los mismos que los de la entrada, ya que ambos tienen el tipo T ).

REGISTER_OP("PolymorphicListExample")
    .Attr("T: list(type)")
    .Input("in: T")
    .Output("out: T");

También puede imponer restricciones sobre los tipos que se pueden especificar en la lista. En el siguiente caso, la entrada es una lista de tensores float y double . La operación acepta, por ejemplo, tipos de entrada (float, double, float) y en ese caso el tipo de salida también sería (float, double, float) .

REGISTER_OP("ListTypeRestrictionExample")
    .Attr("T: list({float, double})")
    .Input("in: T")
    .Output("out: T");

Si desea que todos los tensores de una lista sean del mismo tipo, puede hacer algo como:

REGISTER_OP("IntListInputExample")
    .Attr("N: int")
    .Input("in: N * int32")
    .Output("out: int32");

Esto acepta una lista de tensores int32 y usa un atributo int N para especificar la longitud de la lista.

Esto también se puede hacer de tipo polimórfico . En el siguiente ejemplo, la entrada es una lista de tensores (con longitud "N" ) del mismo tipo (pero no especificado) ( "T" ), y la salida es un solo tensor de tipo coincidente:

REGISTER_OP("SameListInputExample")
    .Attr("N: int")
    .Attr("T: type")
    .Input("in: N * T")
    .Output("out: T");

Por defecto, las listas de tensores tienen una longitud mínima de 1. Puede cambiar ese valor predeterminado usando una restricción ">=" en el atributo correspondiente . En el siguiente ejemplo, la entrada es una lista de al menos 2 tensores int32 :

REGISTER_OP("MinLengthIntListExample")
    .Attr("N: int >= 2")
    .Input("in: N * int32")
    .Output("out: int32");

La misma sintaxis funciona con atributos de "list(type)" :

REGISTER_OP("MinimumLengthPolymorphicListExample")
    .Attr("T: list(type) >= 3")
    .Input("in: T")
    .Output("out: T");

Entradas y salidas

Para resumir lo anterior, un registro de operación puede tener múltiples entradas y salidas:

REGISTER_OP("MultipleInsAndOuts")
    .Input("y: int32")
    .Input("z: float")
    .Output("a: string")
    .Output("b: int32");

Cada especificación de entrada o salida tiene la forma:

<name>: <io-type-expr>

donde <name> comienza con una letra y puede estar compuesto por caracteres alfanuméricos y guiones bajos. <io-type-expr> es una de las siguientes expresiones de tipo:

  • <type> , donde <type> es un tipo de entrada admitido (por ejemplo, float , int32 , string ). Esto especifica un solo tensor del tipo dado.

    Consultetf.DType .

    REGISTER_OP("BuiltInTypesExample")
        .Input("integers: int32")
        .Input("complex_numbers: complex64");
    
  • <attr-type> , donde <attr-type> es el nombre de un Attr con type type o list(type) (con una posible restricción de tipo). Esta sintaxis permite operaciones polimórficas .

    REGISTER_OP("PolymorphicSingleInput")
        .Attr("T: type")
        .Input("in: T");
    
    REGISTER_OP("RestrictedPolymorphicSingleInput")
        .Attr("T: {int32, int64}")
        .Input("in: T");
    

    Hacer referencia a un atributo de list(type) de tipos list(type) permite aceptar una secuencia de tensores.

    REGISTER_OP("ArbitraryTensorSequenceExample")
        .Attr("T: list(type)")
        .Input("in: T")
        .Output("out: T");
    
    REGISTER_OP("RestrictedTensorSequenceExample")
        .Attr("T: list({int32, int64})")
        .Input("in: T")
        .Output("out: T");
    

    Tenga en cuenta que el número y los tipos de tensores en la salida de out es el mismo que en la entrada in entrada, ya que ambos son de tipo T

  • Para una secuencia de tensores con el mismo tipo: <number> * <type> , donde <number> es el nombre de un Attr con tipo int . El <type> puede ser untf.DType o el nombre de un atributo con type type . Como ejemplo del primero, esta int32 acepta una lista de tensores int32 :

    REGISTER_OP("Int32SequenceExample")
        .Attr("NumTensors: int")
        .Input("in: NumTensors * int32")
    

    Considerando que esta operación acepta una lista de tensores de cualquier tipo, siempre que sean todos iguales:

    REGISTER_OP("SameTypeSequenceExample")
        .Attr("NumTensors: int")
        .Attr("T: type")
        .Input("in: NumTensors * T")
    
  • Para una referencia a un tensor: Ref(<type>) , donde <type> es uno de los tipos anteriores.

Se deducirá cualquier atributo utilizado en el tipo de entrada. Por convención, los atributos inferidos usan nombres en mayúsculas (como T o N ). De lo contrario, las entradas, salidas y atributos tienen nombres como parámetros de función (por ejemplo, num_outputs ). Para obtener más detalles, consulte la sección anterior sobre nombres .

Para obtener más detalles, consulte tensorflow/core/framework/op_def_builder.h .

Compatibilidad al revés

Supongamos que ha escrito una operación agradable y personalizada y la ha compartido con otros, por lo que tiene clientes felices utilizando su operación. Sin embargo, le gustaría realizar cambios en la operación de alguna manera.

En general, los cambios a las especificaciones registradas existentes deben ser compatibles con versiones anteriores: cambiar la especificación de una GraphDef no debe romper los GraphDef protocolo GraphDef serializados GraphDef construidos a partir de especificaciones anteriores. Los detalles de la compatibilidad de GraphDef se describen aquí .

Hay varias formas de preservar la compatibilidad con versiones anteriores.

  1. Cualquier atributo nuevo agregado a una operación debe tener valores predeterminados definidos, y con ese valor predeterminado, la operación debe tener el comportamiento original. Para cambiar una operación de no polimórfica a polimórfica, debe asignar un valor predeterminado al nuevo tipo de atributo para conservar la firma original de forma predeterminada. Por ejemplo, si su operación fue:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: float")
        .Output("out: float");
    

    puede hacerlo polimórfico de una manera compatible con versiones anteriores usando:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: T")
        .Output("out: T")
        .Attr("T: numerictype = DT_FLOAT");
    
  2. Puede hacer que una restricción en un atributo sea menos restrictiva. Por ejemplo, puede cambiar de {int32, int64} a {int32, int64, float} o type . O puede cambiar de {"apple", "orange"} a {"apple", "banana", "orange"} o string .

  3. Puede cambiar entradas / salidas individuales en entradas / salidas de lista, siempre que el valor predeterminado para el tipo de lista coincida con la firma anterior.

  4. Puede agregar una nueva entrada / salida de lista, si por defecto está vacía.

  5. Ponga un espacio de nombres en las nuevas operaciones que cree, prefijando los nombres de las operaciones con algo único para su proyecto. Esto evita que su operación choque con cualquier operación que pueda incluirse en versiones futuras de TensorFlow.

  6. ¡Planifique con anticipación! Intente anticipar los usos futuros de la operación. Algunos cambios de firma no se pueden realizar de una manera compatible (por ejemplo, hacer una lista del mismo tipo en una lista de diferentes tipos).

La lista completa de cambios seguros e inseguros se puede encontrar en tensorflow/core/framework/op_compatibility_test.cc . Si no puede hacer que su cambio a una operación sea compatible con versiones anteriores, cree una nueva operación con un nuevo nombre con la nueva semántica.

También tenga en cuenta que si bien estos cambios pueden mantener la compatibilidad con GraphDef , el código Python generado puede cambiar de una manera que no es compatible con los antiguos llamadores. La API de Python puede mantenerse compatible mediante cambios cuidadosos en un contenedor de Python escrito a mano, manteniendo la firma anterior, excepto posiblemente agregando nuevos argumentos opcionales al final. Los cambios generalmente incompatibles solo se pueden realizar cuando TensorFlow cambia las versiones principales y deben cumplir con la semántica de la versión GraphDef .

Soporte de GPU

Puede implementar diferentes OpKernels y registrar uno para CPU y otro para GPU, al igual que puede registrar kernels para diferentes tipos . Hay varios ejemplos de kernels con soporte de GPU en tensorflow/core/kernels/ . Observe que algunos kernels tienen una versión de CPU en un archivo .cc , una versión de GPU en un archivo que termina en _gpu.cu.cc y algún código compartido en común en un archivo .h .

Por ejemplo, tf.pad tiene todo menos el kernel de la GPU en tensorflow/core/kernels/pad_op.cc . El kernel de la GPU está en tensorflow/core/kernels/pad_op_gpu.cu.cc , y el código compartido es una clase con plantilla definida en tensorflow/core/kernels/pad_op.h . Organizamos el código de esta manera por dos razones: le permite compartir código común entre las implementaciones de CPU y GPU, y coloca la implementación de GPU en un archivo separado para que solo pueda compilarlo el compilador de GPU.

Una cosa a tener en cuenta, incluso cuando se usa la versión del kernel de la GPU del pad , todavía necesita su entrada "paddings" en la memoria de la CPU. Para marcar que las entradas o salidas se mantienen en la CPU, agregue una llamada HostMemory() al registro del kernel, por ejemplo:

#define REGISTER_GPU_KERNEL(T)                         \
  REGISTER_KERNEL_BUILDER(Name("Pad")                  \
                              .Device(DEVICE_GPU)      \
                              .TypeConstraint<T>("T")  \
                              .HostMemory("paddings"), \
                          PadOp<GPUDevice, T>)

Compilando el kernel para el dispositivo GPU

Mire cuda_op_kernel.cu.cc para ver un ejemplo que usa un kernel CUDA para implementar una operación. tf_custom_op_library acepta un argumento gpu_srcs en el que se puede especificar la lista de archivos fuente que contienen los núcleos CUDA (archivos *.cu.cc ). Para usar con una instalación binaria de TensorFlow, los núcleos CUDA deben compilarse con el compilador nvcc de NVIDIA. Aquí está la secuencia de comandos que puede usar para compilar cuda_op_kernel.cu.cc y cuda_op_kernel.cc en una única biblioteca cargable dinámicamente:

nvcc -std=c++11 -c -o cuda_op_kernel.cu.o cuda_op_kernel.cu.cc \
  ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC

g++ -std=c++11 -shared -o cuda_op_kernel.so cuda_op_kernel.cc \
  cuda_op_kernel.cu.o ${TF_CFLAGS[@]} -fPIC -lcudart ${TF_LFLAGS[@]}

cuda_op_kernel.so producido anteriormente se puede cargar como de costumbre en Python, usando la función tf.load_op_library .

Tenga en cuenta que si sus bibliotecas CUDA no están instaladas en /usr/local/lib64 , deberá especificar la ruta explícitamente en el segundo comando (g ++) anterior. Por ejemplo, agregue -L /usr/local/cuda-8.0/lib64/ si su CUDA está instalado en /usr/local/cuda-8.0 .

Implementar el degradado en Python

Dado un gráfico de operaciones, TensorFlow usa la diferenciación automática (retropropagación) para agregar nuevas operaciones que representan gradientes con respecto a las operaciones existentes. Para que la diferenciación automática funcione para nuevas operaciones, debe registrar una función de gradiente que calcule los gradientes con respecto a las entradas de las operaciones, dados los gradientes con respecto a las salidas de las operaciones.

Matemáticamente, si una operación calcula \(y = f(x)\), la operación de gradiente registrada convierte los gradientes \(\partial L/ \partial y\) de pérdida \(L\) con respecto a \(y\) en gradientes \(\partial L/ \partial x\) con respecto a \(x\) mediante la regla de la cadena:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial f}{\partial x}.$$

En el caso de ZeroOut , solo una entrada en la entrada afecta la salida, por lo que el gradiente con respecto a la entrada es un tensor escaso "one hot". Esto se expresa de la siguiente manera:

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops

@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
  """The gradients for `zero_out`.

  Args:
    op: The `zero_out` `Operation` that we are differentiating, which we can use
      to find the inputs and outputs of the original op.
    grad: Gradient with respect to the output of the `zero_out` op.

  Returns:
    Gradients with respect to the input of `zero_out`.
  """
  to_zero = op.inputs[0]
  shape = array_ops.shape(to_zero)
  index = array_ops.zeros_like(shape)
  first_grad = array_ops.reshape(grad, [-1])[0]
  to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0)
  return [to_zero_grad]  # List of one Tensor, since we have one input

Detalles sobre el registro de funciones de gradiente con tf.RegisterGradient :

  • Para una operación con una salida, la función de gradiente tomará tf.Operation , op y tf.Tensor grad y construirá nuevas operaciones a partir de los tensores op.inputs[i] , op.outputs[i] y grad . Se puede encontrar información sobre cualquier atributo a través de tf.Operation.get_attr .

  • Si la operación tiene múltiples salidas, la función de gradiente tomará op y grads , donde grads es una lista de gradientes con respecto a cada salida. El resultado de la función de gradiente debe ser una lista de objetos Tensor que representen los gradientes con respecto a cada entrada.

  • Si no hay un gradiente bien definido para alguna entrada, como para las entradas enteras utilizadas como índices, el gradiente devuelto correspondiente debe ser None . Por ejemplo, para una operación que toma un tensor de coma flotante x y un índice entero i , la función de gradiente return [x_grad, None] .

  • Si no hay ningún gradiente significativo para la operación, a menudo no tendrá que registrar ningún gradiente, y siempre que nunca se necesite el gradiente de la operación, estará bien. En algunos casos, una operación no tiene un gradiente bien definido, pero puede participar en el cálculo del gradiente. Aquí puede usar ops.NotDifferentiable para propagar ceros automáticamente hacia atrás.

Tenga en cuenta que en el momento en que se llama a la función de gradiente, solo está disponible el gráfico de flujo de datos de las operaciones, no los datos del tensor en sí. Por lo tanto, todos los cálculos deben realizarse utilizando otras operaciones de flujo tensorial, que se ejecutarán en el momento de la ejecución del gráfico.

Funciones de forma en C ++

La API de TensorFlow tiene una función llamada "inferencia de forma" que proporciona información sobre las formas de los tensores sin tener que ejecutar el gráfico. La inferencia de formas está respaldada por "funciones de formas" que se registran para cada tipo de operación en la declaración REGISTER_OP C ++ y realizan dos funciones: afirmar que las formas de las entradas son compatibles durante la construcción del gráfico y especificar las formas para las salidas.

Las funciones de forma se definen como operaciones en la clase shape_inference::InferenceContext . Por ejemplo, en la función de forma para ZeroOut:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

c->set_output(0, c->input(0)); declara que la forma de la primera salida debe establecerse en la forma de la primera entrada. Si la salida se selecciona por su índice como en el ejemplo anterior, el segundo parámetro de set_output debería ser un objeto ShapeHandle . Puede crear un objeto ShapeHandle vacío mediante su constructor predeterminado. El objeto ShapeHandle para una entrada con índice idx se puede obtener mediante c->input(idx) .

Hay una serie de funciones de forma comunes que se aplican a muchas operaciones, como shape_inference::UnchangedShape que se puede encontrar en common_shape_fns.hy se utiliza de la siguiente manera:

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn(::tensorflow::shape_inference::UnchangedShape);

Una función de forma también puede restringir la forma de una entrada. Para la versión de ZeroOut con una restricción de forma vectorial , la función de forma sería la siguiente:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
      c->set_output(0, input);
      return Status::OK();
    });

La llamada WithRank valida que la forma de entrada c->input(0) tiene una forma con exactamente una dimensión (o si la forma de entrada es desconocida, la forma de salida será un vector con una dimensión desconocida).

Si su operación es polimórfica con múltiples entradas , puede usar miembros de InferenceContext para determinar la cantidad de formas a verificar, y Merge para validar que todas las formas son compatibles (alternativamente, acceda a los atributos que indican las longitudes, con InferenceContext::GetAttr , que proporciona acceso a los atributos de la operación).

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      ::tensorflow::shape_inference::ShapeHandle output;
      for (size_t i = 0; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &input));
        TF_RETURN_IF_ERROR(c->Merge(output, input, &output));
      }
      c->set_output(0, output);
      return Status::OK();
    });

Dado que la inferencia de forma es una característica opcional, y las formas de los tensores pueden variar dinámicamente, las funciones de forma deben ser robustas a la información de forma incompleta para cualquiera de las entradas. El método Merge en InferenceContext permite a la persona que llama afirmar que dos formas son iguales, incluso si una o ambas no tienen información completa. Las funciones de forma se definen para todas las operaciones principales de TensorFlow y proporcionan muchos ejemplos de uso diferentes.

La clase InferenceContext tiene una serie de funciones que se pueden usar para definir manipulaciones de funciones de forma. Por ejemplo, puede validar que una dimensión en particular tiene un valor muy específico usando InferenceContext::Dim e InferenceContext::WithValue ; puede especificar que una dimensión de salida es la suma / producto de dos dimensiones de entrada usando InferenceContext::Add e InferenceContext::Multiply . Consulte la clase InferenceContext para conocer todas las diversas manipulaciones de formas que puede especificar. El siguiente ejemplo establece la forma de la primera salida en (n, 3), donde la primera entrada tiene forma (n, ...)

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    c->set_output(0, c->Matrix(c->Dim(c->input(0), 0), 3));
    return Status::OK();
});

Si tiene una función de forma complicada, debería considerar agregar una prueba para validar que varias combinaciones de formas de entrada producen las combinaciones de formas de salida esperadas. Puede ver ejemplos de cómo escribir estas pruebas en algunas de nuestras pruebas de operaciones principales . (La sintaxis de INFER_OK e INFER_ERROR es un poco críptica, pero intente ser compacta al representar las especificaciones de forma de entrada y salida en las pruebas. Por ahora, consulte los comentarios que rodean esas pruebas para tener una idea de la especificación de la cadena de forma).

Cree un paquete de pip para su operación personalizada

Para construir un paquete pip para su operación , vea el ejemplo de tensorflow / custom-op . Esta guía muestra cómo compilar operaciones personalizadas desde el paquete pip de TensorFlow en lugar de compilar TensorFlow desde la fuente.