Trang này được dịch bởi Cloud Translation API.
Switch to English

Sử dụng trình biên dịch AOT

Tfcompile là gì?

tfcompile là một công cụ độc lập đi trước thời gian (AOT) biên dịch các biểu đồ TensorFlow thành mã thực thi. Nó có thể giảm tổng kích thước nhị phân và cũng tránh được một số chi phí thời gian chạy. Một trường hợp sử dụng điển hình của tfcompile là biên dịch biểu đồ suy luận thành mã thực thi cho thiết bị di động.

Biểu đồ TensorFlow thường được thực thi bởi thời gian chạy TensorFlow. Điều này phát sinh một số chi phí thời gian chạy để thực hiện từng nút trong biểu đồ. Điều này cũng dẫn đến tổng kích thước nhị phân lớn hơn, vì mã cho thời gian chạy TensorFlow cần phải có sẵn, ngoài chính biểu đồ. Mã thực thi được tạo bởi tfcompile không sử dụng thời gian chạy TensorFlow và chỉ có các phụ thuộc vào các hạt nhân thực sự được sử dụng trong tính toán.

Trình biên dịch được xây dựng trên đỉnh của khung XLA. Mã cầu nối TensorFlow với khung XLA nằm trong trình biên dịch / trình biên dịch .

Tfcompile làm gì?

tfcompile lấy một sơ đồ con, được xác định bởi các khái niệm thức ăn và tìm nạp của TensorFlow và tạo ra một hàm thực hiện sơ đồ con đó. Các feeds là các đối số đầu vào cho hàm và các fetches là các đối số đầu ra cho hàm. Tất cả các đầu vào phải được chỉ định đầy đủ bởi các nguồn cấp dữ liệu; sơ đồ con được cắt tỉa kết quả có thể chứa các nút Placeholder hoặc Biến. Thông thường chỉ định tất cả các Placeholder và Biến là nguồn cấp dữ liệu, điều này đảm bảo sơ đồ con kết quả không còn chứa các nút này. Hàm được tạo được đóng gói dưới dạng cc_library , với tệp tiêu đề xuất chữ ký hàm và tệp đối tượng có chứa triển khai. Người dùng viết mã để gọi hàm được tạo khi thích hợp.

Sử dụng tfcompile

Phần này nêu chi tiết các bước ở mức cao để tạo nhị phân thực thi với tfcompile từ sơ đồ con TensorFlow. Các bước là:

  • Bước 1: Cấu hình sơ đồ con để biên dịch
  • Bước 2: Sử dụng macro xây dựng tf_library để biên dịch sơ đồ con
  • Bước 3: Viết mã để gọi sơ đồ con
  • Bước 4: Tạo nhị phân cuối cùng

Bước 1: Cấu hình sơ đồ con để biên dịch

Xác định các nguồn cấp dữ liệu và tìm nạp tương ứng với các đối số đầu vào và đầu ra cho hàm được tạo. Sau đó định cấu hình các feedsfetches trong một proto tensorflow.tf2xla.Config .

 # Each feed is a positional input argument for the generated function.  The order
# of each entry matches the order of each input argument.  Here “x_hold” and “y_hold”
# refer to the names of placeholder nodes defined in the graph.
feed {
  id { node_name: "x_hold" }
  shape {
    dim { size: 2 }
    dim { size: 3 }
  }
}
feed {
  id { node_name: "y_hold" }
  shape {
    dim { size: 3 }
    dim { size: 2 }
  }
}

# Each fetch is a positional output argument for the generated function.  The order
# of each entry matches the order of each output argument.  Here “x_y_prod”
# refers to the name of a matmul node defined in the graph.
fetch {
  id { node_name: "x_y_prod" }
}
 

Bước 2: Sử dụng macro xây dựng tf_l Library để biên dịch sơ đồ con

Bước này chuyển đổi biểu đồ thành cc_library bằng cách sử dụng macro xây dựng tf_library . cc_library bao gồm một tệp đối tượng chứa mã được tạo từ biểu đồ, cùng với tệp tiêu đề cung cấp quyền truy cập vào mã được tạo. tf_library sử dụng tfcompile để biên dịch biểu đồ TensorFlow thành mã thực thi.

 load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")

# Use the tf_library macro to compile your graph into executable code.
tf_library(
    # name is used to generate the following underlying build rules:
    # <name>           : cc_library packaging the generated header and object files
    # <name>_test      : cc_test containing a simple test and benchmark
    # <name>_benchmark : cc_binary containing a stand-alone benchmark with minimal deps;
    #                    can be run on a mobile device
    name = "test_graph_tfmatmul",
    # cpp_class specifies the name of the generated C++ class, with namespaces allowed.
    # The class will be generated in the given namespace(s), or if no namespaces are
    # given, within the global namespace.
    cpp_class = "foo::bar::MatMulComp",
    # graph is the input GraphDef proto, by default expected in binary format.  To
    # use the text format instead, just use the ‘.pbtxt’ suffix.  A subgraph will be
    # created from this input graph, with feeds as inputs and fetches as outputs.
    # No Placeholder or Variable ops may exist in this subgraph.
    graph = "test_graph_tfmatmul.pb",
    # config is the input Config proto, by default expected in binary format.  To
    # use the text format instead, use the ‘.pbtxt’ suffix.  This is where the
    # feeds and fetches were specified above, in the previous step.
    config = "test_graph_tfmatmul.config.pbtxt",
)
 

Để tạo proto GraphDef (test_graph_tfmatmul.pb) cho ví dụ này, hãy chạy make_test_graphs.py và chỉ định vị trí đầu ra với cờ --out_dir.

Các biểu đồ điển hình chứa các Variables đại diện cho các trọng số được học thông qua đào tạo, nhưng tfcompile không thể biên dịch một sơ đồ con có chứa các Variables . Công cụ freeze_graph.py chuyển đổi các biến thành hằng số, sử dụng các giá trị được lưu trữ trong tệp điểm kiểm tra. Để thuận tiện, macro tf_library hỗ trợ đối số freeze_checkpoint , chạy công cụ này. Để biết thêm ví dụ, hãy xem tenorflow / trình biên dịch / aot / tests / BUILD .

Các hằng số hiển thị trong sơ đồ con đã biên dịch được biên dịch trực tiếp vào mã được tạo. Để truyền các hằng vào hàm được tạo, thay vì biên dịch chúng, chỉ cần chuyển chúng vào dưới dạng nguồn cấp dữ liệu.

Để biết chi tiết về macro xây dựng tf_library , xem tfcompile.bzl .

Để biết chi tiết về công cụ tfcompile cơ bản, xem tfcompile_main.cc .

Bước 3: Viết mã để gọi sơ đồ con

Bước này sử dụng tệp tiêu đề ( test_graph_tfmatmul.h ) được tạo bởi macro xây dựng tf_library trong bước trước để gọi mã được tạo. Tệp tiêu đề nằm trong thư mục bazel-bin tương ứng với gói xây dựng và được đặt tên dựa trên thuộc tính tên được đặt trên macro xây dựng tf_library . Ví dụ: tiêu đề được tạo cho test_graph_tfmatmul sẽ là test_graph_tfmatmul.h . Dưới đây là một phiên bản rút gọn của những gì được tạo ra. Tệp được tạo, trong bazel-bin , chứa các nhận xét hữu ích bổ sung.

 namespace foo {
namespace bar {

// MatMulComp represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code.
class MatMulComp {
 public:
  // AllocMode controls the buffer allocation mode.
  enum class AllocMode {
    ARGS_RESULTS_AND_TEMPS,  // Allocate arg, result and temp buffers
    RESULTS_AND_TEMPS_ONLY,  // Only allocate result and temp buffers
  };

  MatMulComp(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS);
  ~MatMulComp();

  // Runs the computation, with inputs read from arg buffers, and outputs
  // written to result buffers. Returns true on success and false on failure.
  bool Run();

  // Arg methods for managing input buffers. Buffers are in row-major order.
  // There is a set of methods for each positional argument.
  void** args();

  void set_arg0_data(float* data);
  float* arg0_data();
  float& arg0(size_t dim0, size_t dim1);

  void set_arg1_data(float* data);
  float* arg1_data();
  float& arg1(size_t dim0, size_t dim1);

  // Result methods for managing output buffers. Buffers are in row-major order.
  // Must only be called after a successful Run call. There is a set of methods
  // for each positional result.
  void** results();


  float* result0_data();
  float& result0(size_t dim0, size_t dim1);
};

}  // end namespace bar
}  // end namespace foo
 

Lớp C ++ được tạo ra được gọi là MatMulComp trong không gian tên foo::bar , bởi vì đó là cpp_class được chỉ định trong macro tf_library . Tất cả các lớp được tạo có một API tương tự, với sự khác biệt duy nhất là các phương thức để xử lý bộ đệm arg và kết quả. Các phương thức đó khác nhau dựa trên số lượng và loại bộ đệm, được chỉ định bởi feedfetch đối số cho macro tf_library .

Có ba loại bộ đệm được quản lý trong lớp được tạo: args đại diện cho đầu vào, results đại diện cho đầu ra và temps đại diện cho bộ đệm tạm thời được sử dụng nội bộ để thực hiện tính toán. Theo mặc định, mỗi phiên bản của lớp được tạo sẽ phân bổ và quản lý tất cả các bộ đệm này cho bạn. Đối số AllocMode tạo AllocMode có thể được sử dụng để thay đổi hành vi này. Tất cả các bộ đệm được căn chỉnh theo ranh giới 64 byte.

Lớp C ++ được tạo chỉ là một trình bao bọc xung quanh mã cấp thấp được tạo bởi XLA.

Ví dụ về việc gọi hàm được tạo dựa trên tfcompile_test.cc :

 #define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

#include <iostream>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" // generated

int main(int argc, char** argv) {
  Eigen::ThreadPool tp(2);  // Size the thread pool as appropriate.
  Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());


  foo::bar::MatMulComp matmul;
  matmul.set_thread_pool(&device);

  // Set up args and run the computation.
  const float args[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  std::copy(args + 0, args + 6, matmul.arg0_data());
  std::copy(args + 6, args + 12, matmul.arg1_data());
  matmul.Run();

  // Check result
  if (matmul.result0(0, 0) == 58) {
    std::cout << "Success" << std::endl;
  } else {
    std::cout << "Failed. Expected value 58 at 0,0. Got:"
              << matmul.result0(0, 0) << std::endl;
  }

  return 0;
}
 

Bước 4: Tạo nhị phân cuối cùng

Bước này kết hợp thư viện được tạo bởi tf_library ở bước 2 và mã được viết ở bước 3 để tạo nhị phân cuối cùng. Dưới đây là một ví dụ bazel tập tin BUILD.

 # Example of linking your binary
# Also see //tensorflow/compiler/aot/tests/BUILD
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")

# The same tf_library call from step 2 above.
tf_library(
    name = "test_graph_tfmatmul",
    ...
)

# The executable code generated by tf_library can then be linked into your code.
cc_binary(
    name = "my_binary",
    srcs = [
        "my_code.cc",  # include test_graph_tfmatmul.h to access the generated header
    ],
    deps = [
        ":test_graph_tfmatmul",  # link in the generated object file
        "//third_party/eigen3",
    ],
    linkopts = [
          "-lpthread",
    ]
)