crear una operación

Si desea crear una operación que no esté cubierta por la biblioteca TensorFlow existente, le recomendamos que primero intente escribir la operación en Python como una composición de operaciones o funciones de Python existentes. Si eso no es posible, puede crear una operación C++ personalizada. Hay varias razones por las que es posible que desee crear una operación C++ personalizada:

  • No es fácil ni posible expresar su operación como una composición de operaciones existentes.
  • No es eficiente expresar su operación como una composición de primitivas existentes.
  • Desea fusionar a mano una composición de primitivas que a un futuro compilador le resultaría difícil fusionar.

Por ejemplo, imagine que desea implementar algo como "agrupación de medianas", similar al operador "MaxPool", pero calculando medianas sobre ventanas deslizantes en lugar de valores máximos. Hacer esto usando una composición de operaciones puede ser posible (por ejemplo, usando ExtractImagePatches y TopK), pero puede no ser tan eficiente en términos de rendimiento o memoria como una operación nativa donde puede hacer algo más inteligente en una sola operación fusionada. Como siempre, por lo general, primero vale la pena tratar de expresar lo que desea mediante la composición de operadores, eligiendo solo agregar una nueva operación si resulta difícil o ineficiente.

Para incorporar su operación personalizada, deberá:

  1. Registre la nueva operación en un archivo C++. El registro de operación define una interfaz (especificación) para la funcionalidad de la operación, que es independiente de la implementación de la operación. Por ejemplo, el registro de operación define el nombre de la operación y las entradas y salidas de la operación. También define la función de forma que se utiliza para la inferencia de forma de tensor.
  2. Implemente la operación en C++. La implementación de una operación se conoce como kernel y es la implementación concreta de la especificación que registró en el Paso 1. Puede haber múltiples kernels para diferentes arquitecturas o tipos de entrada/salida (por ejemplo, CPU, GPU).
  3. Cree un contenedor de Python (opcional). Este contenedor es la API pública que se usa para crear la operación en Python. Se genera un envoltorio predeterminado a partir del registro de operaciones, que se puede usar directamente o agregar.
  4. Escriba una función para calcular los gradientes de la operación (opcional).
  5. Pruebe la op. Usualmente hacemos esto en Python por conveniencia, pero también puede probar la operación en C++. Si define gradientes, puede verificarlos con Python tf.test.compute_gradient_error . Consulte relu_op_test.py como un ejemplo que prueba las funciones directas de operadores similares a Relu y sus gradientes.

requisitos previos

Definir la interfaz de operación

La interfaz de una operación se define registrándola en el sistema TensorFlow. En el registro, especifica el nombre de su operación, sus entradas (tipos y nombres) y salidas (tipos y nombres), así como las cadenas de documentación y cualquier atributo que pueda requerir la operación.

Para ver cómo funciona esto, suponga que desea crear una operación que tome un tensor de int32 s y genere una copia del tensor, con todos menos el primer elemento configurados en cero. Para hacer esto, cree un archivo llamado zero_out.cc . Luego agregue una llamada a la macro REGISTER_OP que define la interfaz para su operación:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

Esta operación ZeroOut toma un tensor to_zero de enteros de 32 bits como entrada y genera un tensor zeroed de enteros de 32 bits. La operación también usa una función de forma para garantizar que el tensor de salida tenga la misma forma que el tensor de entrada. Por ejemplo, si la entrada es un tensor de forma [10, 20], esta función de forma especifica que la forma de salida también es [10, 20].

Implementar el kernel para la operación

Después de definir la interfaz, proporcione una o más implementaciones de la op. Para crear uno de estos kernels, cree una clase que amplíe OpKernel y anule el método Compute . El método Compute proporciona un argumento context de tipo OpKernelContext* , desde el cual puede acceder a cosas útiles como los tensores de entrada y salida.

Agregue su núcleo al archivo que creó anteriormente. El kernel podría verse así:

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};

Después de implementar su kernel, lo registra con el sistema TensorFlow. En el registro, especifica diferentes restricciones bajo las cuales se ejecutará este núcleo. Por ejemplo, puede tener un kernel creado para CPU y otro separado para GPU.

Para hacer esto para la operación ZeroOut , agregue lo siguiente a zero_out.cc :

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

Núcleos de CPU de subprocesos múltiples

Para escribir un kernel de CPU de subprocesos múltiples, se puede usar la función Shard en work_sharder.h . Esta función fragmenta una función de cómputo en los subprocesos configurados para usarse para subprocesos dentro de la operación (consulte intra_op_parallelism_threads en config.proto ).

Núcleos de GPU

Un núcleo de GPU se implementa en dos partes: OpKernel y el núcleo CUDA y su código de lanzamiento.

A veces, la implementación de OpKernel es común entre un kernel de CPU y GPU, como en la inspección de entradas y la asignación de salidas. En ese caso, una implementación sugerida es:

  1. Defina la plantilla de OpKernel en el dispositivo y el tipo primitivo del tensor.
  2. Para realizar el cálculo real de la salida, la función Calcular llama a una estructura de funtor con plantilla.
  3. La especialización de ese funtor para el CPUDevice se define en el mismo archivo, pero la especialización para el GPUDevice se define en un archivo .cu.cc, ya que se compilará con el compilador CUDA.

Aquí hay una implementación de ejemplo.

// kernel_example.h
#ifndef KERNEL_EXAMPLE_H_
#define KERNEL_EXAMPLE_H_

#include <unsupported/Eigen/CXX11/Tensor>

template <typename Device, typename T>
struct ExampleFunctor {
  void operator()(const Device& d, int size, const T* in, T* out);
};

#if GOOGLE_CUDA
// Partially specialize functor for GpuDevice.
template <typename T>
struct ExampleFunctor<Eigen::GpuDevice, T> {
  void operator()(const Eigen::GpuDevice& d, int size, const T* in, T* out);
};
#endif

#endif KERNEL_EXAMPLE_H_
// kernel_example.cc
#include "kernel_example.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

REGISTER_OP("Example")
    .Attr("T: numbertype")
    .Input("input: T")
    .Output("input_times_two: T")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

// CPU specialization of actual computation.
template <typename T>
struct ExampleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, int size, const T* in, T* out) {
    for (int i = 0; i < size; ++i) {
      out[i] = 2 * in[i];
    }
  }
};

// OpKernel definition.
// template parameter <T> is the datatype of the tensors.
template <typename Device, typename T>
class ExampleOp : public OpKernel {
 public:
  explicit ExampleOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));

    // Do the computation.
    OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max,
                errors::InvalidArgument("Too many elements in tensor"));
    ExampleFunctor<Device, T>()(
        context->eigen_device<Device>(),
        static_cast<int>(input_tensor.NumElements()),
        input_tensor.flat<T>().data(),
        output_tensor->flat<T>().data());
  }
};

// Register the CPU kernels.
#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      ExampleOp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(int32);

// Register the GPU kernels.
#ifdef GOOGLE_CUDA
#define REGISTER_GPU(T)                                          \
  /* Declare explicit instantiations in kernel_example.cu.cc. */ \
  extern template class ExampleFunctor<GPUDevice, T>;            \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      ExampleOp<GPUDevice, T>);
REGISTER_GPU(float);
REGISTER_GPU(int32);
#endif  // GOOGLE_CUDA
// kernel_example.cu.cc
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "kernel_example.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

using namespace tensorflow;

using GPUDevice = Eigen::GpuDevice;

// Define the CUDA kernel.
template <typename T>
__global__ void ExampleCudaKernel(const int size, const T* in, T* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += blockDim.x * gridDim.x) {
    out[i] = 2 * __ldg(in + i);
  }
}

// Define the GPU implementation that launches the CUDA kernel.
template <typename T>
void ExampleFunctor<GPUDevice, T>::operator()(
    const GPUDevice& d, int size, const T* in, T* out) {
  // Launch the cuda kernel.
  //
  // See core/util/gpu_kernel_helper.h for example of computing
  // block count and thread_per_block count.
  int block_count = 1024;
  int thread_per_block = 20;
  ExampleCudaKernel<T>
      <<<block_count, thread_per_block, 0, d.stream()>>>(size, in, out);
}

// Explicitly instantiate functors for the types of OpKernels registered.
template struct ExampleFunctor<GPUDevice, float>;
template struct ExampleFunctor<GPUDevice, int32>;

#endif  // GOOGLE_CUDA

Cree la biblioteca de operaciones

Compile la operación usando el compilador de su sistema (instalación binaria de TensorFlow)

Debería poder compilar zero_out.cc con un compilador C++ como g++ o clang disponible en su sistema. El paquete PIP binario instala los archivos de encabezado y la biblioteca que necesita para compilar su operación en ubicaciones que son específicas del sistema. Sin embargo, la biblioteca de Python de TensorFlow proporciona la función get_include para obtener el directorio de encabezado, y el directorio get_lib tiene un objeto compartido con el que vincularse. Aquí están los resultados de estas funciones en una máquina Ubuntu.

$ python
>>> import tensorflow as tf
>>> tf.sysconfig.get_include()
'/usr/local/lib/python3.6/site-packages/tensorflow/include'
>>> tf.sysconfig.get_lib()
'/usr/local/lib/python3.6/site-packages/tensorflow'

Suponiendo que tiene g++ instalado, aquí está la secuencia de comandos que puede usar para compilar su operación en una biblioteca dinámica.

TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++14 -shared zero_out.cc -o zero_out.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2

En macOS, se requiere el indicador adicional "-undefined dynamic_lookup" al crear el archivo .so .

Nota sobre la versión gcc >=5 : gcc usa la nueva ABI de C++ desde la versión 5 . TensorFlow 2.8 y versiones anteriores se crearon con gcc4 que usa la ABI anterior. Si usa estas versiones de TensorFlow y está tratando de compilar su biblioteca de operaciones con gcc>=5 , agregue -D_GLIBCXX_USE_CXX11_ABI=0 a la línea de comando para que la biblioteca sea compatible con la ABI anterior. Los paquetes TensorFlow 2.9+ son compatibles con la ABI más nueva de forma predeterminada.

Compile la operación usando bazel (instalación fuente de TensorFlow)

Si tiene fuentes de TensorFlow instaladas, puede utilizar el sistema de compilación de TensorFlow para compilar su operación. Coloque un archivo BUILD con la siguiente regla de compilación de Bazel en el directorio tensorflow/core/user_ops .

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    name = "zero_out.so",
    srcs = ["zero_out.cc"],
)

Ejecute el siguiente comando para compilar zero_out.so .

$ bazel build --config opt //tensorflow/core/user_ops:zero_out.so

Para compilar la operación Example , con CUDA Kernel, debe usar el parámetro gpu_srcs de tf_custom_op_library . Coloque un archivo BUILD con la siguiente regla de compilación de Bazel en una nueva carpeta dentro del directorio tensorflow/core/user_ops (por ejemplo, "example_gpu").

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    # kernel_example.cc  kernel_example.cu.cc  kernel_example.h
    name = "kernel_example.so",
    srcs = ["kernel_example.h", "kernel_example.cc"],
    gpu_srcs = ["kernel_example.cu.cc", "kernel_example.h"],
)

Ejecute el siguiente comando para compilar kernel_example.so .

$ bazel build --config opt //tensorflow/core/user_ops/example_gpu:kernel_example.so

Usa la op en Python

La API de Python de TensorFlow proporciona la función tf.load_op_library para cargar la biblioteca dinámica y registrar la operación con el marco de TensorFlow. load_op_library devuelve un módulo de Python que contiene los contenedores de Python para la operación y el kernel. Por lo tanto, una vez que haya creado la operación, puede hacer lo siguiente para ejecutarla desde Python:

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
print(zero_out_module.zero_out([[1, 2], [3, 4]]).numpy())

# Prints
array([[1, 0], [0, 0]], dtype=int32)

Tenga en cuenta que la función generada recibirá un nombre de caja de serpiente (para cumplir con PEP8 ). Entonces, si su operación se llama ZeroOut en los archivos de C++, la función de python se llamará zero_out .

Para que la operación esté disponible como una función import regular desde un módulo de Python, puede ser útil tener la llamada load_op_library en un archivo fuente de Python de la siguiente manera:

import tensorflow as tf

zero_out_module = tf.load_op_library('./zero_out.so')
zero_out = zero_out_module.zero_out

Verifique que la operación funcione

Una buena manera de verificar que ha implementado con éxito su operación es escribir una prueba para ello. Cree el archivo zero_out_op_test.py con el contenido:

import tensorflow as tf

class ZeroOutTest(tf.test.TestCase):
  def testZeroOut(self):
    zero_out_module = tf.load_op_library('./zero_out.so')
    with self.test_session():
      result = zero_out_module.zero_out([5, 4, 3, 2, 1])
      self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])

if __name__ == "__main__":
  tf.test.main()

Luego ejecute su prueba (suponiendo que tenga instalado tensorflow):

$ python zero_out_op_test.py

Incorpore funciones avanzadas en su operación

Ahora que sabe cómo crear una operación e implementación básicas (y un tanto restringidas), veremos algunas de las cosas más complicadas que normalmente necesitará incorporar en su operación. Esto incluye:

Comprobaciones condicionales y validación

El ejemplo anterior supuso que op se aplicaba a un tensor de cualquier forma. ¿Qué pasaría si solo se aplicara a los vectores? Eso significa agregar un cheque a la implementación de OpKernel anterior.

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
                errors::InvalidArgument("ZeroOut expects a 1-D vector."));
    // ...
  }

Esto afirma que la entrada es un vector y devuelve haber establecido el estado InvalidArgument si no lo es. La macro OP_REQUIRES toma tres argumentos:

Alternativamente, si desea probar si un objeto Status devuelto por alguna función es un error y, de ser así, devolverlo, use OP_REQUIRES_OK . Ambas macros regresan de la función en caso de error.

registro de operaciones

atributos

Las operaciones pueden tener atributos, cuyos valores se establecen cuando la operación se agrega a un gráfico. Estos se utilizan para configurar la operación y se puede acceder a sus valores tanto dentro de la implementación del kernel como en los tipos de entradas y salidas en el registro de la operación. Prefiere usar una entrada en lugar de un atributo cuando sea posible, ya que las entradas son más flexibles. Esto se debe a que attrs son constantes y deben definirse en el momento de la construcción del gráfico. En contraste, las entradas son Tensores cuyos valores pueden ser dinámicos; es decir, las entradas pueden cambiar cada paso, configurarse mediante un feed, etc. Los atributos se utilizan para cosas que no se pueden hacer con las entradas: cualquier configuración que afecte la firma (número o tipo de entradas o salidas) o que no se pueda t cambiar de paso a paso.

Usted define un atributo cuando registra la operación, especificando su nombre y tipo usando el método Attr , que espera una especificación del formulario:

<name>: <attr-type-expr>

donde <name> comienza con una letra y puede estar compuesto por caracteres alfanuméricos y guiones bajos, y <attr-type-expr> es una expresión de tipo de la forma que se describe a continuación .

Por ejemplo, si desea que la operación ZeroOut conserve un índice especificado por el usuario, en lugar de solo el elemento 0, puede registrar la operación de la siguiente manera:

REGISTER_OP("ZeroOut")
    .Attr("preserve_index: int")
    .Input("to_zero: int32")
    .Output("zeroed: int32");

(Tenga en cuenta que el conjunto de tipos de atributo es diferente del tf.DType utilizado para entradas y salidas).

Su núcleo puede acceder a este atributo en su constructor a través del parámetro context :

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
    // Get the index of the value to preserve
    OP_REQUIRES_OK(context,
                   context->GetAttr("preserve_index", &preserve_index_));
    // Check that preserve_index is positive
    OP_REQUIRES(context, preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ",
                                        preserve_index_));
  }
  void Compute(OpKernelContext* context) override {
    // ...
  }
 private:
  int preserve_index_;
};

que luego se puede utilizar en el método Compute :

  void Compute(OpKernelContext* context) override {
    // ...

    // We're using saved attr to validate potentially dynamic input
    // So we check that preserve_index is in range
    OP_REQUIRES(context, preserve_index_ < input.dimension(0),
                errors::InvalidArgument("preserve_index out of range"));

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the requested input value
    output_flat(preserve_index_) = input(preserve_index_);
  }

tipos de atributos

Los siguientes tipos son compatibles en un attr:

  • string : cualquier secuencia de bytes (no se requiere que sea UTF8).
  • int : Un entero con signo.
  • float : Un número de punto flotante.
  • bool : Verdadero o falso.
  • type : uno de los valores (no ref) de DataType .
  • shape : un TensorShapeProto .
  • list(<type>) : Una lista de <type> , donde <type> es uno de los tipos anteriores. Tenga en cuenta que list(list(<type>)) no es válido.

Consulte también: op_def_builder.cc:FinalizeAttr para obtener una lista definitiva.

Valores predeterminados y restricciones

Los atributos pueden tener valores predeterminados y algunos tipos de atributos pueden tener restricciones. Para definir un atributo con restricciones, puede usar los siguientes <attr-type-expr> s:

{'<string1>', '<string2>'} : el valor debe ser una cadena que tenga el valor <string1> o <string2> . El nombre del tipo, string , está implícito cuando usa esta sintaxis. Esto emula una enumeración:

REGISTER_OP("EnumExample")
    .Attr("e: {'apple', 'orange'}");

{<type1>, <type2>} : el valor es de tipo type y debe ser uno de <type1> o <type2> , donde <type1> y <type2> son compatibles tf.DType . No especifica que el tipo de attr sea type . Esto está implícito cuando tienes una lista de tipos en {...} . Por ejemplo, en este caso el attr t es un tipo que debe ser int32 , float o bool :

REGISTER_OP("RestrictedTypeExample")
    .Attr("t: {int32, float, bool}");

Hay atajos para las restricciones de tipo comunes:

  • numbertype : tipo type restringido a los tipos numéricos (no de cadena y no booleanos).
  • realnumbertype : como numbertype sin tipos complejos.
  • quantizedtype : como numbertype pero solo los tipos de números cuantificados.

Las listas específicas de tipos permitidas por estos están definidas por las funciones (como NumberTypes() ) en tensorflow/core/framework/types.h . En este ejemplo, el attr t debe ser uno de los tipos numéricos:

REGISTER_OP("NumberType")
    .Attr("t: numbertype");

Para esta operación:

tf.number_type(t=tf.int32)  # Valid
tf.number_type(t=tf.bool)   # Invalid

Las listas se pueden combinar con otras listas y tipos únicos. La siguiente operación permite que attr t sea cualquiera de los tipos numéricos o el tipo bool:

REGISTER_OP("NumberOrBooleanType")
    .Attr("t: {numbertype, bool}");

Para esta operación:

tf.number_or_boolean_type(t=tf.int32)  # Valid
tf.number_or_boolean_type(t=tf.bool)   # Valid
tf.number_or_boolean_type(t=tf.string) # Invalid

int >= <n> : El valor debe ser un int cuyo valor sea mayor o igual que <n> , donde <n> es un número natural. Por ejemplo, el siguiente registro de operación especifica que el a debe tener un valor de al menos 2 :

REGISTER_OP("MinIntExample")
    .Attr("a: int >= 2");

list(<type>) >= <n> : una lista de tipo <type> cuya longitud es mayor o igual que <n> . Por ejemplo, el siguiente registro op especifica que attr a es una lista de tipos (ya sea int32 o float ), y que debe haber al menos 3 de ellos:

REGISTER_OP("TypeListExample")
    .Attr("a: list({int32, float}) >= 3");

Para establecer un valor predeterminado para un atributo (haciéndolo opcional en el código generado), agregue = <default> al final, como en:

REGISTER_OP("AttrDefaultExample")
    .Attr("i: int = 0");

Además, se pueden especificar tanto una restricción como un valor predeterminado:

REGISTER_OP("AttrConstraintAndDefaultExample")
    .Attr("i: int >= 1 = 1");

La sintaxis admitida del valor predeterminado es lo que se usaría en la representación proto de la definición de GraphDef resultante.

Estos son ejemplos de cómo especificar un valor predeterminado para todos los tipos:

REGISTER_OP("AttrDefaultExampleForAllTypes")
   .Attr("s: string = 'foo'")
   .Attr("i: int = 0")
   .Attr("f: float = 1.0")
   .Attr("b: bool = true")
   .Attr("ty: type = DT_INT32")
   .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
   .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
   .Attr("l_empty: list(int) = []")
   .Attr("l_int: list(int) = [2, 3, 5, 7]");

Tenga en cuenta en particular que los valores de tipo type usan tf.DType .

Polimorfismo

Tipo de polimorfismo

Para las operaciones que pueden tomar diferentes tipos como entrada o producir diferentes tipos de salida, puede especificar un atributo en un tipo de entrada o salida en el registro de la operación. Por lo general, registraría un OpKernel para cada tipo admitido.

Por ejemplo, si desea que la operación ZeroOut funcione en float s además de int32 s, su registro de operación podría verse así:

REGISTER_OP("ZeroOut")
    .Attr("T: {float, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

Su registro de operaciones ahora especifica que el tipo de entrada debe ser float , o int32 , y que su salida será del mismo tipo, ya que ambos tienen tipo T .

Denominación

Las entradas, salidas y atributos generalmente deben recibir nombres de mayúsculas y minúsculas. La única excepción son los atributos que se utilizan como tipo de entrada o en el tipo de salida. Esos atributos se pueden deducir cuando se agrega la operación al gráfico y, por lo tanto, no aparecen en la función de la operación. Por ejemplo, esta última definición de ZeroOut generará una función de Python que se parece a:

def zero_out(to_zero, name=None):
  """...
  Args:
    to_zero: A `Tensor`. Must be one of the following types:
        `float32`, `int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `to_zero`.
  """

Si to_zero se le pasa un tensor int32 , entonces T se establece automáticamente en int32 (bueno, en realidad DT_INT32 ). Esos atributos inferidos reciben nombres en mayúsculas o CamelCase.

Compare esto con una operación que tiene un atributo de tipo que determina el tipo de salida:

REGISTER_OP("StringToNumber")
    .Input("string_tensor: string")
    .Output("output: out_type")
    .Attr("out_type: {float, int32} = DT_FLOAT");
    .Doc(R"doc(
Converts each string in the input Tensor to the specified numeric type.
)doc");

En este caso, el usuario debe especificar el tipo de salida, como en el Python generado:

def string_to_number(string_tensor, out_type=None, name=None):
  """Converts each string in the input Tensor to the specified numeric type.

  Args:
    string_tensor: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.int32`.
      Defaults to `tf.float32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
Ejemplo de polimorfismo de tipo
#include "tensorflow/core/framework/op_kernel.h"

class ZeroOutInt32Op : public OpKernel {
  // as before
};

class ZeroOutFloatOp : public OpKernel {
 public:
  explicit ZeroOutFloatOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<float>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<float>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutInt32Op);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutFloatOp);

Para preservar la compatibilidad con versiones anteriores , debe especificar un valor predeterminado al agregar un atributo a una operación existente:

REGISTER_OP("ZeroOut")
  .Attr("T: {float, int32} = DT_INT32")
  .Input("to_zero: T")
  .Output("zeroed: T")

Digamos que desea agregar más tipos, digamos double :

REGISTER_OP("ZeroOut")
    .Attr("T: {float, double, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

En lugar de escribir otro OpKernel con código redundante como el anterior, a menudo podrá usar una plantilla de C++ en su lugar. Todavía tendrá un registro de kernel (llamada REGISTER_KERNEL_BUILDER ) por sobrecarga.

template <typename T>
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<T>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<T>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<double>("T"),
    ZeroOutOp<double>);

Si tiene más de un par de sobrecargas, puede colocar el registro en una macro.

#include "tensorflow/core/framework/op_kernel.h"

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);

#undef REGISTER_KERNEL

Dependiendo de la lista de tipos para los que esté registrando el kernel, es posible que pueda usar una macro proporcionada por tensorflow/core/framework/register_types.h :

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"

REGISTER_OP("ZeroOut")
    .Attr("T: realnumbertype")
    .Input("to_zero: T")
    .Output("zeroed: T");

template <typename T>
class ZeroOutOp : public OpKernel { ... };

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL
Lista de entradas y salidas

Además de poder aceptar o producir diferentes tipos, las operaciones pueden consumir o producir un número variable de tensores.

En el siguiente ejemplo, el atributo T contiene una lista de tipos y se usa como el tipo de in y out . La entrada y la salida son listas de tensores de ese tipo (y el número y tipos de tensores en la salida son los mismos que en la entrada, ya que ambos tienen tipo T ).

REGISTER_OP("PolymorphicListExample")
    .Attr("T: list(type)")
    .Input("in: T")
    .Output("out: T");

También puede imponer restricciones sobre qué tipos se pueden especificar en la lista. En el siguiente caso, la entrada es una lista de tensores float y double . La operación acepta, por ejemplo, tipos de entrada (float, double, float) y, en ese caso, el tipo de salida también sería (float, double, float) .

REGISTER_OP("ListTypeRestrictionExample")
    .Attr("T: list({float, double})")
    .Input("in: T")
    .Output("out: T");

Si desea que todos los tensores de una lista sean del mismo tipo, puede hacer algo como:

REGISTER_OP("IntListInputExample")
    .Attr("N: int")
    .Input("in: N * int32")
    .Output("out: int32");

Esto acepta una lista de tensores int32 y usa un int attr N para especificar la longitud de la lista.

Esto también se puede hacer de tipo polimórfico . En el siguiente ejemplo, la entrada es una lista de tensores (con longitud "N" ) del mismo tipo (pero no especificado) ( "T" ), y la salida es un solo tensor de tipo coincidente:

REGISTER_OP("SameListInputExample")
    .Attr("N: int")
    .Attr("T: type")
    .Input("in: N * T")
    .Output("out: T");

De forma predeterminada, las listas de tensores tienen una longitud mínima de 1. Puede cambiar ese valor predeterminado utilizando una restricción ">=" en el attr correspondiente . En el siguiente ejemplo, la entrada es una lista de al menos 2 tensores int32 :

REGISTER_OP("MinLengthIntListExample")
    .Attr("N: int >= 2")
    .Input("in: N * int32")
    .Output("out: int32");

La misma sintaxis funciona con atributos de "list(type)" :

REGISTER_OP("MinimumLengthPolymorphicListExample")
    .Attr("T: list(type) >= 3")
    .Input("in: T")
    .Output("out: T");

Entradas y salidas

Para resumir lo anterior, un registro de operación puede tener múltiples entradas y salidas:

REGISTER_OP("MultipleInsAndOuts")
    .Input("y: int32")
    .Input("z: float")
    .Output("a: string")
    .Output("b: int32");

Cada especificación de entrada o salida tiene la forma:

<name>: <io-type-expr>

donde <name> comienza con una letra y puede estar compuesto por caracteres alfanuméricos y guiones bajos. <io-type-expr> es una de las siguientes expresiones de tipo:

  • <type> , donde <type> es un tipo de entrada admitido (por ejemplo, float , int32 , string ). Esto especifica un solo tensor del tipo dado.

    Ver tf.DType .

    REGISTER_OP("BuiltInTypesExample")
        .Input("integers: int32")
        .Input("complex_numbers: complex64");
    
  • <attr-type> , donde <attr-type> es el nombre de un Attr con tipo type o list(type) (con una posible restricción de tipo). Esta sintaxis permite operaciones polimórficas .

    REGISTER_OP("PolymorphicSingleInput")
        .Attr("T: type")
        .Input("in: T");
    
    REGISTER_OP("RestrictedPolymorphicSingleInput")
        .Attr("T: {int32, int64}")
        .Input("in: T");
    

    Hacer referencia a un atributo de tipo list(type) le permite aceptar una secuencia de tensores.

    REGISTER_OP("ArbitraryTensorSequenceExample")
        .Attr("T: list(type)")
        .Input("in: T")
        .Output("out: T");
    
    REGISTER_OP("RestrictedTensorSequenceExample")
        .Attr("T: list({int32, int64})")
        .Input("in: T")
        .Output("out: T");
    

    Nótese que el número y tipo de tensores en la salida out es el mismo que en la entrada in , ya que ambos son de tipo T .

  • Para una secuencia de tensores con el mismo tipo: <number> * <type> , donde <number> es el nombre de un Attr con tipo int . El <type> puede ser un tf.DType o el nombre de un attr con type type . Como ejemplo del primero, esta operación acepta una lista de tensores int32 :

    REGISTER_OP("Int32SequenceExample")
        .Attr("NumTensors: int")
        .Input("in: NumTensors * int32")
    

    Mientras que esta operación acepta una lista de tensores de cualquier tipo, siempre que sean todos iguales:

    REGISTER_OP("SameTypeSequenceExample")
        .Attr("NumTensors: int")
        .Attr("T: type")
        .Input("in: NumTensors * T")
    
  • Para una referencia a un tensor: Ref(<type>) , donde <type> es uno de los tipos anteriores.

Se deducirá cualquier atributo utilizado en el tipo de una entrada. Por convención, esos atributos inferidos usan nombres en mayúsculas (como T o N ). De lo contrario, las entradas, salidas y atributos tienen nombres como parámetros de función (por ejemplo, num_outputs ). Para obtener más detalles, consulte la sección anterior sobre nombres .

Para obtener más detalles, consulte tensorflow/core/framework/op_def_builder.h .

Compatibilidad al revés

Supongamos que ha escrito una buena operación personalizada y la ha compartido con otros, por lo que tiene clientes satisfechos con su operación. Sin embargo, le gustaría hacer cambios en la operación de alguna manera.

En general, los cambios en las especificaciones registradas existentes deben ser compatibles con versiones anteriores: cambiar la especificación de una operación no debe romper los búferes de protocolo GraphDef serializados anteriores construidos a partir de especificaciones anteriores. Los detalles de la compatibilidad GraphDef se describen aquí .

Hay varias formas de preservar la compatibilidad con versiones anteriores.

  1. Cualquier atributo nuevo que se agregue a una operación debe tener valores predeterminados definidos y, con ese valor predeterminado, la operación debe tener el comportamiento original. Para cambiar una operación de no polimórfica a polimórfica, debe dar un valor predeterminado al nuevo tipo attr para conservar la firma original de forma predeterminada. Por ejemplo, si su operación fuera:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: float")
        .Output("out: float");
    

    puede hacerlo polimórfico de una manera compatible con versiones anteriores usando:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: T")
        .Output("out: T")
        .Attr("T: numerictype = DT_FLOAT");
    
  2. Puede hacer que una restricción en un attr sea menos restrictiva. Por ejemplo, puede cambiar de {int32, int64} a {int32, int64, float} o type . O puede cambiar de {"apple", "orange"} a {"apple", "banana", "orange"} o string .

  3. Puede cambiar entradas/salidas individuales a entradas/salidas de lista, siempre que el valor predeterminado para el tipo de lista coincida con la firma anterior.

  4. Puede agregar una nueva lista de entrada/salida, si por defecto está vacía.

  5. Coloque nombres en cualquier nueva operación que cree, prefijando los nombres de las operaciones con algo único para su proyecto. Esto evita que su operación colisione con cualquier operación que pueda incluirse en futuras versiones de TensorFlow.

  6. ¡Planifica con anticipación! Trate de anticipar usos futuros para la operación. Algunos cambios de firma no se pueden realizar de manera compatible (por ejemplo, convertir una lista del mismo tipo en una lista de tipos diferentes).

La lista completa de cambios seguros e inseguros se puede encontrar en tensorflow/core/framework/op_compatibility_test.cc . Si no puede hacer que su cambio en una operación sea compatible con versiones anteriores, cree una nueva operación con un nuevo nombre con la nueva semántica.

También tenga en cuenta que si bien estos cambios pueden mantener la compatibilidad GraphDef , el código de Python generado puede cambiar de una manera que no sea compatible con las llamadas anteriores. La API de Python puede mantenerse compatible mediante cambios cuidadosos en un contenedor de Python escrito a mano, manteniendo la firma anterior, excepto posiblemente agregando nuevos argumentos opcionales al final. Los cambios generalmente incompatibles solo se pueden realizar cuando TensorFlow cambia las versiones principales y deben cumplir con la semántica de la versión GraphDef .

Soporte de GPU

Puede implementar diferentes OpKernels y registrar uno para CPU y otro para GPU, al igual que puede registrar kernels para diferentes tipos . Hay varios ejemplos de kernels compatibles con GPU en tensorflow/core/kernels/ . Observe que algunos núcleos tienen una versión de CPU en un archivo .cc , una versión de GPU en un archivo que termina en _gpu.cu.cc y algo de código compartido en común en un archivo .h .

Por ejemplo, tf.pad tiene todo menos el kernel de GPU en tensorflow/core/kernels/pad_op.cc . El kernel de GPU está en tensorflow/core/kernels/pad_op_gpu.cu.cc y el código compartido es una clase con plantilla definida en tensorflow/core/kernels/pad_op.h . Organizamos el código de esta manera por dos razones: le permite compartir código común entre las implementaciones de CPU y GPU, y coloca la implementación de GPU en un archivo separado para que solo pueda compilarlo el compilador de GPU.

Una cosa a tener en cuenta, incluso cuando se usa la versión del kernel de GPU de pad , todavía necesita su entrada "paddings" en la memoria de la CPU. Para marcar que las entradas o salidas se mantienen en la CPU, agregue una llamada HostMemory() al registro del kernel, por ejemplo:

#define REGISTER_GPU_KERNEL(T)                         \
  REGISTER_KERNEL_BUILDER(Name("Pad")                  \
                              .Device(DEVICE_GPU)      \
                              .TypeConstraint<T>("T")  \
                              .HostMemory("paddings"), \
                          PadOp<GPUDevice, T>)

Compilando el kernel para el dispositivo GPU

Mire cuda_op_kernel.cu.cc para ver un ejemplo que usa un kernel CUDA para implementar un op. tf_custom_op_library acepta un argumento gpu_srcs en el que se puede especificar la lista de archivos fuente que contienen los núcleos CUDA (archivos *.cu.cc ). Para usar con una instalación binaria de TensorFlow, los núcleos CUDA deben compilarse con el compilador nvcc de NVIDIA. Esta es la secuencia de comandos que puede usar para compilar cuda_op_kernel.cu.cc y cuda_op_kernel.cc en una sola biblioteca cargable dinámicamente:

nvcc -std=c++14 -c -o cuda_op_kernel.cu.o cuda_op_kernel.cu.cc \
  ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC

g++ -std=c++14 -shared -o cuda_op_kernel.so cuda_op_kernel.cc \
  cuda_op_kernel.cu.o ${TF_CFLAGS[@]} -fPIC -lcudart ${TF_LFLAGS[@]}

cuda_op_kernel.so producido anteriormente se puede cargar como de costumbre en Python, usando la función tf.load_op_library .

Tenga en cuenta que si sus bibliotecas CUDA no están instaladas en /usr/local/lib64 , deberá especificar la ruta explícitamente en el segundo comando (g++) anterior. Por ejemplo, agregue -L /usr/local/cuda-8.0/lib64/ si su CUDA está instalado en /usr/local/cuda-8.0 .

Implementar el degradado en Python

Dado un gráfico de operaciones, TensorFlow usa la diferenciación automática (propagación hacia atrás) para agregar nuevas operaciones que representan gradientes con respecto a las operaciones existentes. Para hacer que la diferenciación automática funcione para las nuevas operaciones, debe registrar una función de gradiente que calcule los gradientes con respecto a las entradas de las operaciones dados los gradientes con respecto a las salidas de las operaciones.

Matemáticamente, si un op calcula \(y = f(x)\) el gradiente registrado op convierte los gradientes \(\partial L/ \partial y\) de pérdida \(L\) con respecto a\(y\) en gradientes \(\partial L/ \partial x\) con respecto a \(x\) a través de la regla de la cadena:

\[\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial f}{\partial x}.\]

En el caso de ZeroOut , solo una entrada en la entrada afecta la salida, por lo que el gradiente con respecto a la entrada es un tensor disperso "uno caliente". Esto se expresa de la siguiente manera:

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops

@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
  """The gradients for `zero_out`.

  Args:
    op: The `zero_out` `Operation` that we are differentiating, which we can use
      to find the inputs and outputs of the original op.
    grad: Gradient with respect to the output of the `zero_out` op.

  Returns:
    Gradients with respect to the input of `zero_out`.
  """
  to_zero = op.inputs[0]
  shape = array_ops.shape(to_zero)
  index = array_ops.zeros_like(shape)
  first_grad = array_ops.reshape(grad, [-1])[0]
  to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0)
  return [to_zero_grad]  # List of one Tensor, since we have one input

Detalles sobre el registro de funciones de gradiente con tf.RegisterGradient :

  • Para una operación con una salida, la función de gradiente tomará un tf.Operation , op y un tf.Tensor grad y creará nuevas operaciones a partir de los tensores op.inputs[i] , op.outputs[i] y grad . La información sobre cualquier atributo se puede encontrar a través de tf.Operation.get_attr .

  • Si la operación tiene múltiples salidas, la función de gradiente tomará op y grads , donde grads es una lista de gradientes con respecto a cada salida. El resultado de la función de gradiente debe ser una lista de objetos Tensor que representen los gradientes con respecto a cada entrada.

  • Si no hay un gradiente bien definido para alguna entrada, como las entradas de números enteros que se usan como índices, el gradiente devuelto correspondiente debe ser None . Por ejemplo, para una operación que toma un tensor de coma flotante x y un índice entero i , la función de gradiente return [x_grad, None] .

  • Si no hay ningún gradiente significativo para la operación, a menudo no tendrá que registrar ningún gradiente, y mientras nunca se necesite el gradiente de la operación, estará bien. En algunos casos, una operación no tiene un gradiente bien definido pero puede estar involucrada en el cálculo del gradiente. Aquí puede usar ops.NotDifferentiable para propagar automáticamente ceros hacia atrás.

Tenga en cuenta que en el momento en que se llama a la función de gradiente, solo está disponible el gráfico de flujo de datos de operaciones, no los datos del tensor en sí. Por lo tanto, todos los cálculos deben realizarse utilizando otras operaciones de tensorflow, para ejecutarse en el momento de la ejecución del gráfico.

Agregue sugerencias de tipo al registrar el degradado personalizado para un tipo de operación para que el código sea más legible, depurable, más fácil de mantener y más sólido a través de la validación de datos. Por ejemplo, al tomar una op como parámetro en una función, especifique que la función de gradiente tomará una tf.Operation como su tipo de parámetro.

Funciones de forma en C++

La API de TensorFlow tiene una característica llamada "inferencia de formas" que proporciona información sobre las formas de los tensores sin tener que ejecutar el gráfico. La inferencia de forma es compatible con "funciones de forma" que se registran para cada tipo de operación en la declaración REGISTER_OP de C++ y desempeñan dos funciones: afirmar que las formas de las entradas son compatibles durante la construcción del gráfico y especificar las formas para las salidas.

Las funciones de forma se definen como operaciones en la clase shape_inference::InferenceContext . Por ejemplo, en la función de forma para ZeroOut:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

c->set_output(0, c->input(0)); declara que la forma de la primera salida debe establecerse en la forma de la primera entrada. Si la salida se selecciona por su índice como en el ejemplo anterior, el segundo parámetro de set_output debe ser un objeto ShapeHandle . Puede crear un objeto ShapeHandle vacío mediante su constructor predeterminado. El objeto ShapeHandle para una entrada con índice idx se puede obtener mediante c->input(idx) .

Hay una serie de funciones de forma comunes que se aplican a muchas operaciones, como shape_inference::UnchangedShape , que se puede encontrar en common_shape_fns.h y se usa de la siguiente manera:

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn(::tensorflow::shape_inference::UnchangedShape);

Una función de forma también puede restringir la forma de una entrada. Para la versión de ZeroOut con una restricción de forma vectorial , la función de forma sería la siguiente:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
      c->set_output(0, input);
      return Status::OK();
    });

La llamada WithRank valida que la forma de entrada c->input(0) tiene una forma con exactamente una dimensión (o si la forma de entrada es desconocida, la forma de salida será un vector con una dimensión desconocida).

Si su operación es polimórfica con múltiples entradas , puede usar miembros de InferenceContext para determinar la cantidad de formas para verificar y Merge para validar que las formas sean todas compatibles (alternativamente, acceda a los atributos que indican las longitudes, con InferenceContext::GetAttr , que proporciona acceso a los atributos de la op).

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      ::tensorflow::shape_inference::ShapeHandle output;
      for (size_t i = 0; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &input));
        TF_RETURN_IF_ERROR(c->Merge(output, input, &output));
      }
      c->set_output(0, output);
      return Status::OK();
    });

Dado que la inferencia de forma es una función opcional y las formas de los tensores pueden variar dinámicamente, las funciones de forma deben ser resistentes a la información de forma incompleta para cualquiera de las entradas. El método Merge en InferenceContext permite a la persona que llama afirmar que dos formas son iguales, incluso si una o ambas no tienen información completa. Las funciones de forma están definidas para todas las operaciones principales de TensorFlow y brindan muchos ejemplos de uso diferentes.

La clase InferenceContext tiene una serie de funciones que se pueden usar para definir manipulaciones de funciones de forma. Por ejemplo, puede validar que una dimensión en particular tiene un valor muy específico usando InferenceContext::Dim e InferenceContext::WithValue ; puede especificar que una dimensión de salida sea la suma o el producto de dos dimensiones de entrada mediante InferenceContext::Add e InferenceContext::Multiply . Consulte la clase InferenceContext para conocer todas las diversas manipulaciones de formas que puede especificar. El siguiente ejemplo establece la forma de la primera salida en (n, 3), donde la primera entrada tiene forma (n, ...)

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    c->set_output(0, c->Matrix(c->Dim(c->input(0), 0), 3));
    return Status::OK();
});

Si tiene una función de forma complicada, debería considerar agregar una prueba para validar que varias combinaciones de formas de entrada producen las combinaciones de formas de salida esperadas. Puede ver ejemplos de cómo escribir estas pruebas en algunas de nuestras pruebas de operaciones principales . (La sintaxis de INFER_OK e INFER_ERROR es un poco críptica, pero intente ser compacto al representar las especificaciones de forma de entrada y salida en las pruebas. Por ahora, consulte los comentarios que rodean esas pruebas para tener una idea de la especificación de cadena de forma).

Cree un paquete pip para su operación personalizada

Para crear un paquete pip para su operación, consulte el ejemplo de tensorflow/custom-op . Esta guía muestra cómo crear operaciones personalizadas a partir del paquete pip de TensorFlow en lugar de crear TensorFlow desde el origen.