Criar uma operação

Se você quiser criar uma operação que não seja coberta pela biblioteca existente do TensorFlow, recomendamos que primeiro tente escrever a operação em Python como uma composição de operações ou funções Python existentes. Se isso não for possível, você pode criar uma operação C++ personalizada. Há vários motivos pelos quais você pode querer criar uma operação C++ personalizada:

  • Não é fácil ou possível expressar sua operação como uma composição de operações existentes.
  • Não é eficiente expressar sua operação como uma composição de primitivas existentes.
  • Você deseja fundir manualmente uma composição de primitivos que um compilador futuro acharia difícil fundir.

Por exemplo, imagine que você deseja implementar algo como "pooling mediano", semelhante ao operador "MaxPool", mas calculando medianas em janelas deslizantes em vez de valores máximos. Fazer isso usando uma composição de operações pode ser possível (por exemplo, usando ExtractImagePatches e TopK), mas pode não ser tão eficiente em termos de desempenho ou memória quanto uma operação nativa em que você pode fazer algo mais inteligente em uma única operação fundida. Como sempre, vale a pena primeiro tentar expressar o que você deseja usando a composição do operador, optando apenas por adicionar uma nova operação se isso for difícil ou ineficiente.

Para incorporar sua operação personalizada, você precisará:

  1. Registre a nova operação em um arquivo C++. O registro de operação define uma interface (especificação) para a funcionalidade da operação, que é independente da implementação da operação. Por exemplo, o registro de operação define o nome da operação e as entradas e saídas da operação. Ele também define a função de forma que é usada para inferência de forma de tensor.
  2. Implemente a operação em C++. A implementação de um op é conhecida como kernel e é a implementação concreta da especificação que você registrou na Etapa 1. Pode haver vários kernels para diferentes tipos ou arquiteturas de entrada/saída (por exemplo, CPUs, GPUs).
  3. Crie um wrapper Python (opcional). Esse wrapper é a API pública usada para criar a operação em Python. Um wrapper padrão é gerado a partir do registro op, que pode ser usado diretamente ou adicionado a ele.
  4. Escreva uma função para calcular gradientes para a operação (opcional).
  5. Teste a op. Geralmente fazemos isso em Python por conveniência, mas você também pode testar a operação em C++. Se você definir gradientes, poderá verificá-los com Python tf.test.compute_gradient_error . Consulte relu_op_test.py como um exemplo que testa as funções de encaminhamento de operadores semelhantes a Relu e seus gradientes.

Pré-requisitos

Definir a interface operacional

Você define a interface de uma operação registrando-a no sistema TensorFlow. No registro, você especifica o nome do seu op, suas entradas (tipos e nomes) e saídas (tipos e nomes), bem como docstrings e quaisquer atributos que o op possa exigir.

Para ver como isso funciona, suponha que você gostaria de criar um op que recebe um tensor de int32 s e gera uma cópia do tensor, com todos, exceto o primeiro elemento definido como zero. Para fazer isso, crie um arquivo chamado zero_out.cc . Em seguida, adicione uma chamada à macro REGISTER_OP que define a interface para sua operação:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

Essa operação ZeroOut usa um tensor to_zero de inteiros de 32 bits como entrada e gera um tensor zeroed de inteiros de 32 bits. A operação também usa uma função de forma para garantir que o tensor de saída tenha a mesma forma que o tensor de entrada. Por exemplo, se a entrada for um tensor de forma [10, 20], então esta função de forma especifica que a forma de saída também é [10, 20].

Implemente o kernel para o op

Depois de definir a interface, forneça uma ou mais implementações do op. Para criar um desses kernels, crie uma classe que estenda OpKernel e substitua o método Compute . O método Compute fornece um argumento context do tipo OpKernelContext* , a partir do qual você pode acessar coisas úteis como os tensores de entrada e saída.

Adicione seu kernel ao arquivo que você criou acima. O kernel pode ser algo como isto:

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // Set all but the first element of the output tensor to 0.
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value if possible.
    if (N > 0) output_flat(0) = input(0);
  }
};

Depois de implementar seu kernel, você o registra no sistema TensorFlow. No registro, você especifica diferentes restrições sob as quais esse kernel será executado. Por exemplo, você pode ter um kernel feito para CPUs e outro separado para GPUs.

Para fazer isso para a operação ZeroOut , adicione o seguinte a zero_out.cc :

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

Kernels de CPU multi-threaded

Para escrever um kernel de CPU multi-threaded, a função Shard em work_sharder.h pode ser usada. Essa função fragmenta uma função de computação entre os encadeamentos configurados para serem usados ​​para encadeamento intra-operacional (consulte intra_op_parallelism_threads em config.proto ).

núcleos da GPU

Um kernel GPU é implementado em duas partes: o OpKernel e o kernel CUDA e seu código de inicialização.

Às vezes, a implementação do OpKernel é comum entre um kernel de CPU e GPU, como na inspeção de entradas e na alocação de saídas. Nesse caso, uma implementação sugerida é:

  1. Defina o OpKernel modelado no Dispositivo e o tipo primitivo do tensor.
  2. Para fazer o cálculo real da saída, a função Compute chama uma estrutura de functor modelo.
  3. A especialização desse functor para o CPUDevice é definida no mesmo arquivo, mas a especialização para o GPUDevice é definida em um arquivo .cu.cc, pois será compilado com o compilador CUDA.

Aqui está um exemplo de implementação.

// kernel_example.h
#ifndef KERNEL_EXAMPLE_H_
#define KERNEL_EXAMPLE_H_

#include <unsupported/Eigen/CXX11/Tensor>

template <typename Device, typename T>
struct ExampleFunctor {
  void operator()(const Device& d, int size, const T* in, T* out);
};

#if GOOGLE_CUDA
// Partially specialize functor for GpuDevice.
template <typename T>
struct ExampleFunctor<Eigen::GpuDevice, T> {
  void operator()(const Eigen::GpuDevice& d, int size, const T* in, T* out);
};
#endif

#endif KERNEL_EXAMPLE_H_
// kernel_example.cc
#include "kernel_example.h"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

REGISTER_OP("Example")
    .Attr("T: numbertype")
    .Input("input: T")
    .Output("input_times_two: T")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

// CPU specialization of actual computation.
template <typename T>
struct ExampleFunctor<CPUDevice, T> {
  void operator()(const CPUDevice& d, int size, const T* in, T* out) {
    for (int i = 0; i < size; ++i) {
      out[i] = 2 * in[i];
    }
  }
};

// OpKernel definition.
// template parameter <T> is the datatype of the tensors.
template <typename Device, typename T>
class ExampleOp : public OpKernel {
 public:
  explicit ExampleOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    // Create an output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));

    // Do the computation.
    OP_REQUIRES(context, input_tensor.NumElements() <= tensorflow::kint32max,
                errors::InvalidArgument("Too many elements in tensor"));
    ExampleFunctor<Device, T>()(
        context->eigen_device<Device>(),
        static_cast<int>(input_tensor.NumElements()),
        input_tensor.flat<T>().data(),
        output_tensor->flat<T>().data());
  }
};

// Register the CPU kernels.
#define REGISTER_CPU(T)                                          \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      ExampleOp<CPUDevice, T>);
REGISTER_CPU(float);
REGISTER_CPU(int32);

// Register the GPU kernels.
#ifdef GOOGLE_CUDA
#define REGISTER_GPU(T)                                          \
  /* Declare explicit instantiations in kernel_example.cu.cc. */ \
  extern template class ExampleFunctor<GPUDevice, T>;            \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Example").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      ExampleOp<GPUDevice, T>);
REGISTER_GPU(float);
REGISTER_GPU(int32);
#endif  // GOOGLE_CUDA
// kernel_example.cu.cc
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "kernel_example.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

using namespace tensorflow;

using GPUDevice = Eigen::GpuDevice;

// Define the CUDA kernel.
template <typename T>
__global__ void ExampleCudaKernel(const int size, const T* in, T* out) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += blockDim.x * gridDim.x) {
    out[i] = 2 * __ldg(in + i);
  }
}

// Define the GPU implementation that launches the CUDA kernel.
template <typename T>
void ExampleFunctor<GPUDevice, T>::operator()(
    const GPUDevice& d, int size, const T* in, T* out) {
  // Launch the cuda kernel.
  //
  // See core/util/gpu_kernel_helper.h for example of computing
  // block count and thread_per_block count.
  int block_count = 1024;
  int thread_per_block = 20;
  ExampleCudaKernel<T>
      <<<block_count, thread_per_block, 0, d.stream()>>>(size, in, out);
}

// Explicitly instantiate functors for the types of OpKernels registered.
template struct ExampleFunctor<GPUDevice, float>;
template struct ExampleFunctor<GPUDevice, int32>;

#endif  // GOOGLE_CUDA

Crie a biblioteca operacional

Compile a operação usando o compilador do sistema (instalação binária do TensorFlow)

Você deve ser capaz de compilar zero_out.cc com um compilador C++ como g++ ou clang disponível em seu sistema. O pacote PIP binário instala os arquivos de cabeçalho e a biblioteca necessária para compilar sua operação em locais específicos do sistema. No entanto, a biblioteca python TensorFlow fornece a função get_include para obter o diretório de cabeçalho, e o diretório get_lib tem um objeto compartilhado para vincular. Aqui estão as saídas dessas funções em uma máquina Ubuntu.

$ python
>>> import tensorflow as tf
>>> tf.sysconfig.get_include()
'/usr/local/lib/python3.6/site-packages/tensorflow/include'
>>> tf.sysconfig.get_lib()
'/usr/local/lib/python3.6/site-packages/tensorflow'

Supondo que você tenha g++ instalado, aqui está a sequência de comandos que você pode usar para compilar sua operação em uma biblioteca dinâmica.

TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++14 -shared zero_out.cc -o zero_out.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2

No macOS, o sinalizador adicional "-undefined dynamic_lookup" é necessário ao criar o arquivo .so .

Observação sobre a versão gcc >=5 : o gcc usa o novo C++ ABI desde a versão 5 . O TensorFlow 2.8 e anteriores foram criados com gcc4 que usa a ABI mais antiga. Se você estiver usando essas versões do TensorFlow e estiver tentando compilar sua biblioteca op com gcc>=5 , adicione -D_GLIBCXX_USE_CXX11_ABI=0 à linha de comando para tornar a biblioteca compatível com a ABI mais antiga. Os pacotes do TensorFlow 2.9+ são compatíveis com a ABI mais recente por padrão.

Compile a operação usando o bazel (instalação da fonte do TensorFlow)

Se você tiver as fontes do TensorFlow instaladas, poderá usar o sistema de compilação do TensorFlow para compilar seu op. Coloque um arquivo BUILD com a seguinte regra de compilação Bazel no diretório tensorflow/core/user_ops .

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    name = "zero_out.so",
    srcs = ["zero_out.cc"],
)

Execute o seguinte comando para criar zero_out.so .

$ bazel build --config opt //tensorflow/core/user_ops:zero_out.so

Para compilar a operação Example , com o Kernel CUDA, você precisa usar o parâmetro gpu_srcs de tf_custom_op_library . Coloque um arquivo BUILD com a seguinte regra de compilação Bazel em uma nova pasta dentro do diretório tensorflow/core/user_ops (por exemplo, "example_gpu").

load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")

tf_custom_op_library(
    # kernel_example.cc  kernel_example.cu.cc  kernel_example.h
    name = "kernel_example.so",
    srcs = ["kernel_example.h", "kernel_example.cc"],
    gpu_srcs = ["kernel_example.cu.cc", "kernel_example.h"],
)

Execute o seguinte comando para criar kernel_example.so .

$ bazel build --config opt //tensorflow/core/user_ops/example_gpu:kernel_example.so

Use o op em Python

A API Python do TensorFlow fornece a função tf.load_op_library para carregar a biblioteca dinâmica e registrar a operação com a estrutura do TensorFlow. load_op_library retorna um módulo Python que contém os wrappers Python para o op e o kernel. Assim, depois de criar a operação, você pode fazer o seguinte para executá-la a partir do Python:

import tensorflow as tf
zero_out_module = tf.load_op_library('./zero_out.so')
print(zero_out_module.zero_out([[1, 2], [3, 4]]).numpy())

# Prints
array([[1, 0], [0, 0]], dtype=int32)

Lembre-se de que a função gerada receberá um nome snake_case (para cumprir com o PEP8 ). Portanto, se sua operação for denominada ZeroOut nos arquivos C++, a função python será chamada zero_out .

Para tornar o op disponível como uma função regular import de um módulo Python, pode ser útil ter a chamada load_op_library em um arquivo de origem Python da seguinte forma:

import tensorflow as tf

zero_out_module = tf.load_op_library('./zero_out.so')
zero_out = zero_out_module.zero_out

Verifique se a operação funciona

Uma boa maneira de verificar se você implementou com sucesso sua operação é escrever um teste para ela. Crie o arquivo zero_out_op_test.py com o conteúdo:

import tensorflow as tf

class ZeroOutTest(tf.test.TestCase):
  def testZeroOut(self):
    zero_out_module = tf.load_op_library('./zero_out.so')
    with self.test_session():
      result = zero_out_module.zero_out([5, 4, 3, 2, 1])
      self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])

if __name__ == "__main__":
  tf.test.main()

Em seguida, execute seu teste (supondo que você tenha o tensorflow instalado):

$ python zero_out_op_test.py

Crie recursos avançados em sua operação

Agora que você sabe como construir uma operação e implementação básicas (e um tanto restritas), veremos algumas das coisas mais complicadas que você normalmente precisará criar em sua operação. Isso inclui:

Verificações e validação condicional

O exemplo acima assumiu que o op aplicado a um tensor de qualquer forma. E se aplicasse apenas a vetores? Isso significa adicionar uma verificação à implementação do OpKernel acima.

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
                errors::InvalidArgument("ZeroOut expects a 1-D vector."));
    // ...
  }

Isso afirma que a entrada é um vetor e retorna com o status InvalidArgument definido se não for. A macro OP_REQUIRES recebe três argumentos:

Como alternativa, se você quiser testar se um objeto Status retornado de alguma função é um erro e, se for o caso, retorná-lo, use OP_REQUIRES_OK . Ambas as macros retornam da função em caso de erro.

registro operacional

Atributos

Ops podem ter attrs, cujos valores são definidos quando o op é adicionado a um gráfico. Estes são usados ​​para configurar o op, e seus valores podem ser acessados ​​tanto dentro da implementação do kernel quanto nos tipos de entradas e saídas no registro do op. Prefira usar uma entrada em vez de um atributo quando possível, pois as entradas são mais flexíveis. Isso ocorre porque attrs são constantes e devem ser definidas no momento da construção do gráfico. Em contraste, as entradas são tensores cujos valores podem ser dinâmicos; ou seja, as entradas podem mudar a cada passo, ser definidas usando um feed, etc. Attrs são usados ​​para coisas que não podem ser feitas com entradas: qualquer configuração que afete a assinatura (número ou tipo de entradas ou saídas) ou que não pode t mudar de passo a passo.

Você define um attr quando registra o op, especificando seu nome e tipo usando o método Attr , que espera uma especificação do formulário:

<name>: <attr-type-expr>

onde <name> começa com uma letra e pode ser composto de caracteres alfanuméricos e sublinhados, e <attr-type-expr> é uma expressão de tipo no formato descrito abaixo .

Por exemplo, se você quiser que a operação ZeroOut preserve um índice especificado pelo usuário, em vez de apenas o elemento 0, você pode registrar a operação da seguinte forma:

REGISTER_OP("ZeroOut")
    .Attr("preserve_index: int")
    .Input("to_zero: int32")
    .Output("zeroed: int32");

(Observe que o conjunto de tipos de atributo é diferente do tf.DType usado para entradas e saídas.)

Seu kernel pode então acessar este attr em seu construtor através do parâmetro context :

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {
    // Get the index of the value to preserve
    OP_REQUIRES_OK(context,
                   context->GetAttr("preserve_index", &preserve_index_));
    // Check that preserve_index is positive
    OP_REQUIRES(context, preserve_index_ >= 0,
                errors::InvalidArgument("Need preserve_index >= 0, got ",
                                        preserve_index_));
  }
  void Compute(OpKernelContext* context) override {
    // ...
  }
 private:
  int preserve_index_;
};

que pode então ser usado no método Compute :

  void Compute(OpKernelContext* context) override {
    // ...

    // We're using saved attr to validate potentially dynamic input
    // So we check that preserve_index is in range
    OP_REQUIRES(context, preserve_index_ < input.dimension(0),
                errors::InvalidArgument("preserve_index out of range"));

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the requested input value
    output_flat(preserve_index_) = input(preserve_index_);
  }

tipos de atributos

Os seguintes tipos são suportados em um attr:

  • string : Qualquer sequência de bytes (não precisa ser UTF8).
  • int : Um inteiro com sinal.
  • float : Um número de ponto flutuante.
  • bool : Verdadeiro ou falso.
  • type : um dos valores (não ref) de DataType .
  • shape : A TensorShapeProto .
  • list(<type>) : Uma lista de <type> , onde <type> é um dos tipos acima. Observe que list(list(<type>)) é inválido.

Consulte também: op_def_builder.cc:FinalizeAttr para obter uma lista definitiva.

Valores padrão e restrições

Attrs podem ter valores padrão e alguns tipos de attrs podem ter restrições. Para definir um atributo com restrições, você pode usar os seguintes <attr-type-expr> s:

{'<string1>', '<string2>'} : O valor deve ser uma string que tenha o valor <string1> ou <string2> . O nome do tipo, string , está implícito quando você usa essa sintaxe. Isso emula uma enumeração:

REGISTER_OP("EnumExample")
    .Attr("e: {'apple', 'orange'}");

{<type1>, <type2>} : O valor é do tipo type , e deve ser <type1> ou <type2> , onde <type1> e <type2> são suportados tf.DType . Você não especifica que o tipo do attr é type . Isso está implícito quando você tem uma lista de tipos em {...} . Por exemplo, neste caso o attr t é um tipo que deve ser um int32 , um float ou um bool :

REGISTER_OP("RestrictedTypeExample")
    .Attr("t: {int32, float, bool}");

Existem atalhos para restrições de tipo comuns:

  • numbertype : tipo de type restrito aos tipos numéricos (não string e não bool).
  • realnumbertype : Como numbertype sem tipos complexos.
  • quantizedtype : como numbertype , mas apenas os tipos de números quantizados.

As listas específicas de tipos permitidos por eles são definidas pelas funções (como NumberTypes() ) em tensorflow/core/framework/types.h . Neste exemplo, o t deve ser um dos tipos numéricos:

REGISTER_OP("NumberType")
    .Attr("t: numbertype");

Para esta operação:

tf.number_type(t=tf.int32)  # Valid
tf.number_type(t=tf.bool)   # Invalid

As listas podem ser combinadas com outras listas e tipos únicos. A operação a seguir permite que attr t seja qualquer um dos tipos numéricos ou o tipo bool:

REGISTER_OP("NumberOrBooleanType")
    .Attr("t: {numbertype, bool}");

Para esta operação:

tf.number_or_boolean_type(t=tf.int32)  # Valid
tf.number_or_boolean_type(t=tf.bool)   # Valid
tf.number_or_boolean_type(t=tf.string) # Invalid

int >= <n> : O valor deve ser um int cujo valor seja maior ou igual a <n> , onde <n> é um número natural. Por exemplo, o registro operacional a seguir especifica que attr a deve ter um valor de pelo menos 2 :

REGISTER_OP("MinIntExample")
    .Attr("a: int >= 2");

list(<type>) >= <n> : Uma lista do tipo <type> cujo comprimento é maior ou igual a <n> . Por exemplo, o registro operacional a seguir especifica que attr a é uma lista de tipos ( int32 ou float ) e que deve haver pelo menos 3 deles:

REGISTER_OP("TypeListExample")
    .Attr("a: list({int32, float}) >= 3");

Para definir um valor padrão para um attr (tornando-o opcional no código gerado), adicione = <default> ao final, como em:

REGISTER_OP("AttrDefaultExample")
    .Attr("i: int = 0");

Além disso, tanto uma restrição quanto um valor padrão podem ser especificados:

REGISTER_OP("AttrConstraintAndDefaultExample")
    .Attr("i: int >= 1 = 1");

A sintaxe suportada do valor padrão é o que seria usado na representação proto da definição de GraphDef resultante.

Aqui estão exemplos de como especificar um padrão para todos os tipos:

REGISTER_OP("AttrDefaultExampleForAllTypes")
   .Attr("s: string = 'foo'")
   .Attr("i: int = 0")
   .Attr("f: float = 1.0")
   .Attr("b: bool = true")
   .Attr("ty: type = DT_INT32")
   .Attr("sh: shape = { dim { size: 1 } dim { size: 2 } }")
   .Attr("te: tensor = { dtype: DT_INT32 int_val: 5 }")
   .Attr("l_empty: list(int) = []")
   .Attr("l_int: list(int) = [2, 3, 5, 7]");

Observe em particular que os valores do tipo type usam tf.DType .

Polimorfismo

polimorfismo de tipo

Para operações que podem receber diferentes tipos como entrada ou produzir diferentes tipos de saída, você pode especificar um atributo em um tipo de entrada ou saída no registro de operação. Normalmente, você registraria um OpKernel para cada tipo suportado.

Por exemplo, se você quiser que o ZeroOut op funcione em float s além de int32 s, seu registro op pode se parecer com:

REGISTER_OP("ZeroOut")
    .Attr("T: {float, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

Seu registro de operação agora especifica que o tipo da entrada deve ser float , ou int32 , e que sua saída será do mesmo tipo, pois ambas têm tipo T .

Nomenclatura

Entradas, saídas e atributos geralmente devem receber nomes de snake_case. A única exceção são attrs que são usados ​​como o tipo de uma entrada ou no tipo de uma saída. Esses atributos podem ser inferidos quando o op é adicionado ao gráfico e, portanto, não aparecem na função do op. Por exemplo, esta última definição de ZeroOut irá gerar uma função Python que se parece com:

def zero_out(to_zero, name=None):
  """...
  Args:
    to_zero: A `Tensor`. Must be one of the following types:
        `float32`, `int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `to_zero`.
  """

Se to_zero for passado um tensor int32 , T será automaticamente definido como int32 (bem, na verdade DT_INT32 ). Esses atributos inferidos recebem nomes em letras maiúsculas ou CamelCase.

Compare isso com um op que tem um tipo attr que determina o tipo de saída:

REGISTER_OP("StringToNumber")
    .Input("string_tensor: string")
    .Output("output: out_type")
    .Attr("out_type: {float, int32} = DT_FLOAT");
    .Doc(R"doc(
Converts each string in the input Tensor to the specified numeric type.
)doc");

Neste caso, o usuário deve especificar o tipo de saída, como no Python gerado:

def string_to_number(string_tensor, out_type=None, name=None):
  """Converts each string in the input Tensor to the specified numeric type.

  Args:
    string_tensor: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.int32`.
      Defaults to `tf.float32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
Exemplo de polimorfismo de tipo
#include "tensorflow/core/framework/op_kernel.h"

class ZeroOutInt32Op : public OpKernel {
  // as before
};

class ZeroOutFloatOp : public OpKernel {
 public:
  explicit ZeroOutFloatOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<float>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<float>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutInt32Op);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutFloatOp);

Para preservar a compatibilidade com versões anteriores , você deve especificar um valor padrão ao adicionar um atributo a um op existente:

REGISTER_OP("ZeroOut")
  .Attr("T: {float, int32} = DT_INT32")
  .Input("to_zero: T")
  .Output("zeroed: T")

Digamos que você queira adicionar mais tipos, digamos double :

REGISTER_OP("ZeroOut")
    .Attr("T: {float, double, int32}")
    .Input("to_zero: T")
    .Output("zeroed: T");

Em vez de escrever outro OpKernel com código redundante como acima, muitas vezes você poderá usar um modelo C++. Você ainda terá um registro de kernel (chamada REGISTER_KERNEL_BUILDER ) por sobrecarga.

template <typename T>
class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // Grab the input tensor
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<T>();

    // Create an output tensor
    Tensor* output = NULL;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_tensor.shape(), &output));
    auto output_flat = output->template flat<T>();

    // Set all the elements of the output tensor to 0
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      output_flat(i) = 0;
    }

    // Preserve the first input value
    if (N > 0) output_flat(0) = input(0);
  }
};

// Note that TypeConstraint<int32>("T") means that attr "T" (defined
// in the op registration above) must be "int32" to use this template
// instantiation.
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<int32>("T"),
    ZeroOutOp<int32>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<float>("T"),
    ZeroOutOp<float>);
REGISTER_KERNEL_BUILDER(
    Name("ZeroOut")
    .Device(DEVICE_CPU)
    .TypeConstraint<double>("T"),
    ZeroOutOp<double>);

Se você tiver mais de algumas sobrecargas, poderá colocar o registro em uma macro.

#include "tensorflow/core/framework/op_kernel.h"

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

REGISTER_KERNEL(int32);
REGISTER_KERNEL(float);
REGISTER_KERNEL(double);

#undef REGISTER_KERNEL

Dependendo da lista de tipos para os quais você está registrando o kernel, você pode usar uma macro fornecida por tensorflow/core/framework/register_types.h :

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"

REGISTER_OP("ZeroOut")
    .Attr("T: realnumbertype")
    .Input("to_zero: T")
    .Output("zeroed: T");

template <typename T>
class ZeroOutOp : public OpKernel { ... };

#define REGISTER_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("ZeroOut").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ZeroOutOp<type>)

TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);

#undef REGISTER_KERNEL
Listar entradas e saídas

Além de poder aceitar ou produzir tipos diferentes, ops pode consumir ou produzir um número variável de tensores.

No próximo exemplo, o attr T contém uma lista de tipos e é usado como o tipo de in e out . A entrada e a saída são listas de tensores desse tipo (e o número e os tipos de tensores na saída são os mesmos da entrada, pois ambos são do tipo T ).

REGISTER_OP("PolymorphicListExample")
    .Attr("T: list(type)")
    .Input("in: T")
    .Output("out: T");

Você também pode colocar restrições sobre quais tipos podem ser especificados na lista. Neste próximo caso, a entrada é uma lista de tensores float e double . O op aceita, por exemplo, tipos de entrada (float, double, float) e nesse caso o tipo de saída também seria (float, double, float) .

REGISTER_OP("ListTypeRestrictionExample")
    .Attr("T: list({float, double})")
    .Input("in: T")
    .Output("out: T");

Se você quiser que todos os tensores em uma lista sejam do mesmo tipo, você pode fazer algo como:

REGISTER_OP("IntListInputExample")
    .Attr("N: int")
    .Input("in: N * int32")
    .Output("out: int32");

Isso aceita uma lista de tensores int32 e usa um int attr N para especificar o comprimento da lista.

Isso também pode ser feito de tipo polimórfico . No próximo exemplo, a entrada é uma lista de tensores (com comprimento "N" ) do mesmo tipo (mas não especificado) ( "T" ) e a saída é um único tensor de tipo correspondente:

REGISTER_OP("SameListInputExample")
    .Attr("N: int")
    .Attr("T: type")
    .Input("in: N * T")
    .Output("out: T");

Por padrão, as listas de tensores têm um comprimento mínimo de 1. Você pode alterar esse padrão usando uma restrição ">=" no attr correspondente . Neste próximo exemplo, a entrada é uma lista de pelo menos 2 tensores int32 :

REGISTER_OP("MinLengthIntListExample")
    .Attr("N: int >= 2")
    .Input("in: N * int32")
    .Output("out: int32");

A mesma sintaxe funciona com os atributos "list(type)" :

REGISTER_OP("MinimumLengthPolymorphicListExample")
    .Attr("T: list(type) >= 3")
    .Input("in: T")
    .Output("out: T");

Entradas e saídas

Para resumir o que foi dito acima, um registro operacional pode ter várias entradas e saídas:

REGISTER_OP("MultipleInsAndOuts")
    .Input("y: int32")
    .Input("z: float")
    .Output("a: string")
    .Output("b: int32");

Cada especificação de entrada ou saída é da forma:

<name>: <io-type-expr>

onde <name> começa com uma letra e pode ser composto de caracteres alfanuméricos e sublinhados. <io-type-expr> é uma das seguintes expressões de tipo:

  • <type> , onde <type> é um tipo de entrada suportado (por exemplo, float , int32 , string ). Isso especifica um único tensor do tipo fornecido.

    Veja tf.DType .

    REGISTER_OP("BuiltInTypesExample")
        .Input("integers: int32")
        .Input("complex_numbers: complex64");
    
  • <attr-type> , onde <attr-type> é o nome de um Attr com tipo de type ou list(type) (com uma possível restrição de tipo). Essa sintaxe permite operações polimórficas .

    REGISTER_OP("PolymorphicSingleInput")
        .Attr("T: type")
        .Input("in: T");
    
    REGISTER_OP("RestrictedPolymorphicSingleInput")
        .Attr("T: {int32, int64}")
        .Input("in: T");
    

    Referenciar um atributo do tipo list(type) permite que você aceite uma sequência de tensores.

    REGISTER_OP("ArbitraryTensorSequenceExample")
        .Attr("T: list(type)")
        .Input("in: T")
        .Output("out: T");
    
    REGISTER_OP("RestrictedTensorSequenceExample")
        .Attr("T: list({int32, int64})")
        .Input("in: T")
        .Output("out: T");
    

    Observe que o número e os tipos de tensores na saída out são os mesmos da entrada in , pois ambos são do tipo T .

  • Para uma sequência de tensores com o mesmo tipo: <number> * <type> , onde <number> é o nome de um Attr com tipo int . O <type> pode ser um tf.DType , ou o nome de um attr com o tipo type . Como exemplo do primeiro, esta operação aceita uma lista de tensores int32 :

    REGISTER_OP("Int32SequenceExample")
        .Attr("NumTensors: int")
        .Input("in: NumTensors * int32")
    

    Considerando que esta operação aceita uma lista de tensores de qualquer tipo, desde que sejam todos iguais:

    REGISTER_OP("SameTypeSequenceExample")
        .Attr("NumTensors: int")
        .Attr("T: type")
        .Input("in: NumTensors * T")
    
  • Para uma referência a um tensor: Ref(<type>) , onde <type> é um dos tipos anteriores.

Qualquer atributo usado no tipo de uma entrada será inferido. Por convenção, esses atributos inferidos usam nomes maiúsculos (como T ou N ). Caso contrário, entradas, saídas e atributos têm nomes como parâmetros de função (por exemplo, num_outputs ). Para obter mais detalhes, consulte a seção anterior sobre como nomear .

Para obter mais detalhes, consulte tensorflow/core/framework/op_def_builder.h .

Compatibilidade com versões anteriores

Vamos supor que você tenha escrito uma operação legal e personalizada e a tenha compartilhado com outras pessoas, para que você tenha clientes satisfeitos usando sua operação. No entanto, você gostaria de fazer alterações na operação de alguma forma.

Em geral, as alterações nas especificações verificadas existentes devem ser compatíveis com versões anteriores: alterar a especificação de uma operação não deve interromper os buffers de protocolo GraphDef serializados anteriores construídos a partir de especificações mais antigas. Os detalhes da compatibilidade GraphDef são descritos aqui .

Existem várias maneiras de preservar a compatibilidade com versões anteriores.

  1. Quaisquer novos atributos adicionados a uma operação devem ter valores padrão definidos e, com esse valor padrão, a operação deve ter o comportamento original. Para alterar uma operação de não polimórfica para polimórfica, você deve fornecer um valor padrão ao novo tipo attr para preservar a assinatura original por padrão. Por exemplo, se sua operação foi:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: float")
        .Output("out: float");
    

    você pode torná-lo polimórfico de maneira compatível com versões anteriores usando:

    REGISTER_OP("MyGeneralUnaryOp")
        .Input("in: T")
        .Output("out: T")
        .Attr("T: numerictype = DT_FLOAT");
    
  2. Você pode seguramente tornar uma restrição em um attr menos restritiva. Por exemplo, você pode alterar de {int32, int64} para {int32, int64, float} ou type . Ou você pode mudar de {"apple", "orange"} para {"apple", "banana", "orange"} ou string .

  3. Você pode alterar entradas/saídas individuais para entradas/saídas de lista, desde que o padrão para o tipo de lista corresponda à assinatura antiga.

  4. Você pode adicionar uma nova entrada/saída de lista, se o padrão for vazio.

  5. Namespace quaisquer novas operações que você criar, prefixando os nomes das operações com algo exclusivo para seu projeto. Isso evita que sua operação colida com quaisquer operações que possam ser incluídas em versões futuras do TensorFlow.

  6. Planejar com antecedência! Tente antecipar usos futuros para o op. Algumas alterações de assinatura não podem ser feitas de maneira compatível (por exemplo, transformar uma lista do mesmo tipo em uma lista de tipos variados).

A lista completa de alterações seguras e inseguras pode ser encontrada em tensorflow/core/framework/op_compatibility_test.cc . Se você não puder fazer sua alteração em uma operação compatível com versões anteriores, crie uma nova operação com um novo nome com a nova semântica.

Observe também que, embora essas alterações possam manter a compatibilidade GraphDef , o código Python gerado pode ser alterado de maneira incompatível com os chamadores antigos. A API do Python pode ser mantida compatível por meio de alterações cuidadosas em um wrapper Python escrito à mão, mantendo a assinatura antiga, exceto possivelmente adicionando novos argumentos opcionais ao final. Alterações geralmente incompatíveis só podem ser feitas quando o TensorFlow alterar as versões principais e devem estar em conformidade com a semântica da versão GraphDef .

suporte GPU

Você pode implementar diferentes OpKernels e registrar um para CPU e outro para GPU, assim como você pode registrar kernels para diferentes tipos . Existem vários exemplos de kernels com suporte a GPU em tensorflow/core/kernels/ . Observe que alguns kernels têm uma versão de CPU em um arquivo .cc , uma versão de GPU em um arquivo que termina em _gpu.cu.cc e algum código compartilhado em um arquivo .h .

Por exemplo, o tf.pad tem tudo menos o kernel da GPU em tensorflow/core/kernels/pad_op.cc . O kernel da GPU está em tensorflow/core/kernels/pad_op_gpu.cu.cc , e o código compartilhado é uma classe de modelo definida em tensorflow/core/kernels/pad_op.h . Organizamos o código dessa maneira por dois motivos: permite que você compartilhe código comum entre as implementações de CPU e GPU e coloca a implementação de GPU em um arquivo separado para que possa ser compilado apenas pelo compilador de GPU.

Uma coisa a observar, mesmo quando a versão do pad do kernel da GPU é usada, ela ainda precisa de sua entrada "paddings" na memória da CPU. Para marcar que as entradas ou saídas são mantidas na CPU, adicione uma chamada HostMemory() ao registro do kernel, por exemplo:

#define REGISTER_GPU_KERNEL(T)                         \
  REGISTER_KERNEL_BUILDER(Name("Pad")                  \
                              .Device(DEVICE_GPU)      \
                              .TypeConstraint<T>("T")  \
                              .HostMemory("paddings"), \
                          PadOp<GPUDevice, T>)

Compilando o kernel para o dispositivo GPU

Veja cuda_op_kernel.cu.cc para um exemplo que usa um kernel CUDA para implementar um op. A tf_custom_op_library aceita um argumento gpu_srcs no qual a lista de arquivos de origem contendo os kernels CUDA (arquivos *.cu.cc ) pode ser especificada. Para uso com uma instalação binária do TensorFlow, os kernels CUDA devem ser compilados com o compilador nvcc da NVIDIA. Aqui está a sequência de comandos que você pode usar para compilar cuda_op_kernel.cu.cc e cuda_op_kernel.cc em uma única biblioteca carregável dinamicamente:

nvcc -std=c++14 -c -o cuda_op_kernel.cu.o cuda_op_kernel.cu.cc \
  ${TF_CFLAGS[@]} -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC

g++ -std=c++14 -shared -o cuda_op_kernel.so cuda_op_kernel.cc \
  cuda_op_kernel.cu.o ${TF_CFLAGS[@]} -fPIC -lcudart ${TF_LFLAGS[@]}

cuda_op_kernel.so produzido acima pode ser carregado normalmente em Python, usando a função tf.load_op_library .

Observe que, se suas bibliotecas CUDA não estiverem instaladas em /usr/local/lib64 , você precisará especificar o caminho explicitamente no segundo comando (g++) acima. Por exemplo, adicione -L /usr/local/cuda-8.0/lib64/ se seu CUDA estiver instalado em /usr/local/cuda-8.0 .

Implemente o gradiente em Python

Dado um gráfico de operações, o TensorFlow usa diferenciação automática (backpropagation) para adicionar novas operações que representam gradientes em relação às operações existentes. Para fazer a diferenciação automática funcionar para novas operações, você deve registrar uma função de gradiente que calcula os gradientes em relação às entradas das operações, dados os gradientes em relação às saídas das operações.

Matematicamente, se um op calcula \(y = f(x)\) o gradiente registrado op converte gradientes \(\partial L/ \partial y\) de perda \(L\) em relação a\(y\) em gradientes \(\partial L/ \partial x\) em relação a \(x\) por meio da regra da cadeia:

\[\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} \frac{\partial f}{\partial x}.\]

No caso de ZeroOut , apenas uma entrada na entrada afeta a saída, portanto, o gradiente em relação à entrada é um tensor "um quente" esparso. Isso é expresso da seguinte forma:

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops

@ops.RegisterGradient("ZeroOut")
def _zero_out_grad(op, grad):
  """The gradients for `zero_out`.

  Args:
    op: The `zero_out` `Operation` that we are differentiating, which we can use
      to find the inputs and outputs of the original op.
    grad: Gradient with respect to the output of the `zero_out` op.

  Returns:
    Gradients with respect to the input of `zero_out`.
  """
  to_zero = op.inputs[0]
  shape = array_ops.shape(to_zero)
  index = array_ops.zeros_like(shape)
  first_grad = array_ops.reshape(grad, [-1])[0]
  to_zero_grad = sparse_ops.sparse_to_dense([index], shape, first_grad, 0)
  return [to_zero_grad]  # List of one Tensor, since we have one input

Detalhes sobre o registro de funções de gradiente com tf.RegisterGradient :

  • Para um op com uma saída, a função de gradiente pegará um tf.Operation , op e um tf.Tensor grad e criará novos ops a partir dos tensores op.inputs[i] , op.outputs[i] e grad . Informações sobre quaisquer attrs podem ser encontradas via tf.Operation.get_attr .

  • Se o op tiver várias saídas, a função gradiente receberá op e grads , onde grads é uma lista de gradientes em relação a cada saída. O resultado da função de gradiente deve ser uma lista de objetos Tensor representando os gradientes em relação a cada entrada.

  • Se não houver gradiente bem definido para alguma entrada, como para entradas inteiras usadas como índices, o gradiente retornado correspondente deve ser None . Por exemplo, para uma operação usando um tensor de ponto flutuante x e um índice inteiro i , a função de gradiente return [x_grad, None] .

  • Se não houver nenhum gradiente significativo para a operação, geralmente você não precisará registrar nenhum gradiente e, desde que o gradiente da operação nunca seja necessário, você ficará bem. Em alguns casos, um op não possui um gradiente bem definido, mas pode estar envolvido no cálculo do gradiente. Aqui você pode usar ops.NotDifferentiable para propagar zeros automaticamente para trás.

Observe que, no momento em que a função gradiente é chamada, apenas o gráfico de fluxo de dados de ops está disponível, não os próprios dados do tensor. Assim, toda a computação deve ser realizada usando outras operações do tensorflow, a serem executadas no tempo de execução do gráfico.

Adicione dicas de tipo ao registrar o gradiente personalizado para um tipo de operação para tornar o código mais legível, depurável, mais fácil de manter e mais robusto por meio da validação de dados. Por exemplo, ao usar um op como parâmetro em uma função, especifique que a função gradiente usará um tf.Operation como seu tipo de parâmetro.

Funções de forma em C++

A API do TensorFlow possui um recurso chamado "shape inference" que fornece informações sobre as formas dos tensores sem a necessidade de executar o gráfico. A inferência de forma é suportada por "funções de forma" que são registradas para cada tipo op na declaração C++ REGISTER_OP e executam duas funções: afirmar que as formas das entradas são compatíveis durante a construção do gráfico e especificar as formas para as saídas.

As funções de forma são definidas como operações na classe shape_inference::InferenceContext . Por exemplo, na função de forma para ZeroOut:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

c->set_output(0, c->input(0)); declara que a forma da primeira saída deve ser definida como a forma da primeira entrada. Se a saída for selecionada por seu índice como no exemplo acima, o segundo parâmetro de set_output deve ser um objeto ShapeHandle . Você pode criar um objeto ShapeHandle vazio por seu construtor padrão. O objeto ShapeHandle para uma entrada com índice idx pode ser obtido por c->input(idx) .

Há várias funções de forma comuns que se aplicam a muitas operações, como shape_inference::UnchangedShape , que pode ser encontrada em common_shape_fns.h e usada da seguinte maneira:

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn(::tensorflow::shape_inference::UnchangedShape);

Uma função de forma também pode restringir a forma de uma entrada. Para a versão de ZeroOut com uma restrição de forma vetorial , a função de forma seria a seguinte:

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
      c->set_output(0, input);
      return Status::OK();
    });

A chamada WithRank valida que a forma de entrada c->input(0) tem uma forma com exatamente uma dimensão (ou se a forma de entrada for desconhecida, a forma de saída será um vetor com uma dimensão desconhecida).

Se sua operação for polimórfica com várias entradas , você pode usar membros de InferenceContext para determinar o número de formas a serem verificadas e Merge para validar se as formas são todas compatíveis (alternativamente, acesse atributos que indicam os comprimentos, com InferenceContext::GetAttr , que fornece acesso aos atributos do op).

    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      ::tensorflow::shape_inference::ShapeHandle input;
      ::tensorflow::shape_inference::ShapeHandle output;
      for (size_t i = 0; i < c->num_inputs(); ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &input));
        TF_RETURN_IF_ERROR(c->Merge(output, input, &output));
      }
      c->set_output(0, output);
      return Status::OK();
    });

Como a inferência de forma é um recurso opcional e as formas dos tensores podem variar dinamicamente, as funções de forma devem ser robustas para informações de forma incompletas para qualquer uma das entradas. O método Merge em InferenceContext permite que o chamador assegure que duas formas são iguais, mesmo que uma ou ambas não tenham informações completas. As funções de forma são definidas para todas as operações principais do TensorFlow e fornecem muitos exemplos de uso diferentes.

A classe InferenceContext tem várias funções que podem ser usadas para definir manipulações de função de forma. Por exemplo, você pode validar se uma dimensão específica tem um valor muito específico usando InferenceContext::Dim e InferenceContext::WithValue ; você pode especificar que uma dimensão de saída é a soma/produto de duas dimensões de entrada usando InferenceContext::Add e InferenceContext::Multiply . Consulte a classe InferenceContext para todas as várias manipulações de forma que você pode especificar. O exemplo a seguir define a forma da primeira saída como (n, 3), onde a primeira entrada tem a forma (n, ...)

.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    c->set_output(0, c->Matrix(c->Dim(c->input(0), 0), 3));
    return Status::OK();
});

Se você tiver uma função de forma complicada, considere adicionar um teste para validar que várias combinações de forma de entrada produzem as combinações de forma de saída esperadas. Você pode ver exemplos de como escrever esses testes em alguns de nossos testes de operações principais . (A sintaxe de INFER_OK e INFER_ERROR é um pouco enigmática, mas tente ser compacto ao representar as especificações de formato de entrada e saída nos testes. Por enquanto, veja os comentários ao redor desses testes para ter uma noção da especificação da string de formato).

Crie um pacote pip para sua operação personalizada

Para criar um pacote pip para sua operação, consulte o exemplo tensorflow/custom-op . Este guia mostra como criar operações personalizadas a partir do pacote pip do TensorFlow em vez de compilar o TensorFlow a partir do código-fonte.