Google I / O là một kết quả hoàn hảo! Cập nhật các phiên TensorFlow Xem phiên

Giới thiệu về đồ thị và hàm tf.

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Tổng quat

Hướng dẫn này nằm bên dưới bề mặt của TensorFlow và Keras để chứng minh cách hoạt động của TensorFlow. Thay vào đó, nếu bạn muốn bắt đầu ngay với Keras, hãy xem bộ sưu tập các hướng dẫn về Keras .

Trong hướng dẫn này, bạn sẽ tìm hiểu cách TensorFlow cho phép bạn thực hiện các thay đổi đơn giản đối với mã của mình để lấy biểu đồ, cách biểu đồ được lưu trữ và biểu diễn cũng như cách bạn có thể sử dụng chúng để tăng tốc mô hình của mình.

Đây là một bức tranh tổng quan bao gồm cách tf.function cho phép bạn chuyển từ thực thi háo hức sang thực thi đồ thị. Để có thông số kỹ thuật đầy đủ hơn về tf.function , hãy truy cập hướng dẫn tf.function .

Đồ thị là gì?

Trong ba hướng dẫn trước, bạn đã chạy TensorFlow một cách háo hức . Điều này có nghĩa là các hoạt động TensorFlow được thực thi bởi Python, hoạt động theo hoạt động và trả kết quả trở lại Python.

Trong khi thực thi háo hức có một số lợi thế riêng, thực thi đồ thị cho phép tính di động bên ngoài Python và có xu hướng cung cấp hiệu suất tốt hơn. Thực thi đồ thị có nghĩa là các phép tính tensor được thực hiện dưới dạng đồ thị TensorFlow , đôi khi được gọi là tf.Graph hoặc đơn giản là "đồ thị".

Đồ thị là cấu trúc dữ liệu chứa một tập hợp các đối tượng tf.Operation , đại diện cho các đơn vị tính toán; và các đối tượng tf.Tensor , đại diện cho các đơn vị dữ liệu lưu chuyển giữa các hoạt động. Chúng được định nghĩa trong ngữ cảnh tf.Graph . Vì những đồ thị này là cấu trúc dữ liệu, chúng có thể được lưu, chạy và khôi phục tất cả mà không cần mã Python gốc.

Đây là biểu đồ TensorFlow đại diện cho mạng nơ-ron hai lớp trông như thế nào khi được hiển thị trong TensorBoard.

Một biểu đồ TensorFlow đơn giản

Lợi ích của đồ thị

Với một biểu đồ, bạn có rất nhiều tính linh hoạt. Bạn có thể sử dụng biểu đồ TensorFlow của mình trong các môi trường không có trình thông dịch Python, như các ứng dụng di động, thiết bị nhúng và máy chủ phụ trợ. TensorFlow sử dụng đồ thị làm định dạng cho các mô hình đã lưu khi xuất chúng từ Python.

Đồ thị cũng được tối ưu hóa dễ dàng, cho phép trình biên dịch thực hiện các phép biến đổi như:

  • Suy ra tĩnh giá trị của tensors bằng cách gấp các nút không đổi trong tính toán của bạn ("gấp liên tục") .
  • Tách các phần con của một phép tính độc lập và chia chúng giữa các luồng hoặc thiết bị.
  • Đơn giản hóa các phép toán số học bằng cách loại bỏ các biểu thức con thông thường.

Có toàn bộ hệ thống tối ưu hóa, Grappler , để thực hiện việc này và các cách tăng tốc khác.

Tóm lại, đồ thị cực kỳ hữu ích và cho phép TensorFlow của bạn chạy nhanh , chạy song song và chạy hiệu quả trên nhiều thiết bị .

Tuy nhiên, bạn vẫn muốn xác định các mô hình học máy (hoặc các phép tính khác) của mình bằng Python để thuận tiện và sau đó tự động xây dựng đồ thị khi bạn cần.

Thành lập

import tensorflow as tf
import timeit
from datetime import datetime

Tận dụng đồ thị

Bạn tạo và chạy một biểu đồ trong TensorFlow bằng cách sử dụng tf.function . function, dưới dạng một cuộc gọi trực tiếp hoặc như một người trang trí. tf.function nhận một hàm thông thường làm đầu vào và trả về một Function . Function là một hàm có thể gọi trong Python để xây dựng đồ thị TensorFlow từ hàm Python. Bạn sử dụng một Function theo cách tương tự với hàm tương đương trong Python của nó.

# Define a Python function.
def a_regular_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# `a_function_that_uses_a_graph` is a TensorFlow `Function`.
a_function_that_uses_a_graph = tf.function(a_regular_function)

# Make some tensors.
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

orig_value = a_regular_function(x1, y1, b1).numpy()
# Call a `Function` like a Python function.
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert(orig_value == tf_function_value)

Nhìn bên ngoài, một Function trông giống như một hàm thông thường mà bạn viết bằng các phép toán TensorFlow. Bên dưới , tuy nhiên, nó rất khác . Một Function đóng gói một số tf.Graph phía sau một API . Đó là cách Function có thể cung cấp cho bạn những lợi ích của việc thực thi đồ thị , chẳng hạn như tốc độ và khả năng triển khai.

tf.function áp dụng cho một hàm và tất cả các hàm khác mà nó gọi :

def inner_function(x, y, b):
  x = tf.matmul(x, y)
  x = x + b
  return x

# Use the decorator to make `outer_function` a `Function`.
@tf.function
def outer_function(x):
  y = tf.constant([[2.0], [3.0]])
  b = tf.constant(4.0)

  return inner_function(x, y, b)

# Note that the callable will create a graph that
# includes `inner_function` as well as `outer_function`.
outer_function(tf.constant([[1.0, 2.0]])).numpy()
array([[12.]], dtype=float32)

Nếu bạn đã sử dụng TensorFlow 1.x, bạn sẽ nhận thấy rằng bạn không cần phải xác định Placeholder hoặc tf.Session tại bất kỳ thời điểm nào.

Chuyển đổi các hàm Python thành đồ thị

Bất kỳ hàm nào bạn viết bằng TensorFlow sẽ chứa hỗn hợp các hoạt động TF tích hợp sẵn và logic Python, chẳng hạn như mệnh đề if-then , vòng lặp, break , return , continue và hơn thế nữa. Trong khi các hoạt động TensorFlow dễ dàng được nắm bắt bởi một tf.Graph , logic cụ thể của Python cần phải trải qua một bước bổ sung để trở thành một phần của biểu đồ. tf.function sử dụng một thư viện có tên là AutoGraph ( tf.autograph ) để chuyển đổi mã Python thành mã tạo đồ thị.

def simple_relu(x):
  if tf.greater(x, 0):
    return x
  else:
    return 0

# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.
tf_simple_relu = tf.function(simple_relu)

print("First branch, with graph:", tf_simple_relu(tf.constant(1)).numpy())
print("Second branch, with graph:", tf_simple_relu(tf.constant(-1)).numpy())
First branch, with graph: 1
Second branch, with graph: 0

Mặc dù không chắc rằng bạn sẽ cần xem đồ thị trực tiếp, nhưng bạn có thể kiểm tra kết quả đầu ra để kiểm tra kết quả chính xác. Những điều này không dễ đọc, vì vậy không cần phải xem xét quá kỹ lưỡng!

# This is the graph-generating output of AutoGraph.
print(tf.autograph.to_code(simple_relu))
def tf__simple_relu(x):
    with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal retval_, do_return
            (do_return, retval_) = vars_

        def if_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = ag__.ld(x)
            except:
                do_return = False
                raise

        def else_body():
            nonlocal retval_, do_return
            try:
                do_return = True
                retval_ = 0
            except:
                do_return = False
                raise
        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
        return fscope.ret(retval_, do_return)
# This is the graph itself.
print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())
node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "x"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "Greater/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Greater"
  op: "Greater"
  input: "x"
  input: "Greater/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "cond"
  op: "StatelessIf"
  input: "Greater"
  input: "x"
  attr {
    key: "Tcond"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "Tin"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_BOOL
        type: DT_INT32
      }
    }
  }
  attr {
    key: "_lower_using_switch_merge"
    value {
      b: true
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "else_branch"
    value {
      func {
        name: "cond_false_34"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
        shape {
        }
      }
    }
  }
  attr {
    key: "then_branch"
    value {
      func {
        name: "cond_true_33"
      }
    }
  }
}
node {
  name: "cond/Identity"
  op: "Identity"
  input: "cond"
  attr {
    key: "T"
    value {
      type: DT_BOOL
    }
  }
}
node {
  name: "cond/Identity_1"
  op: "Identity"
  input: "cond:1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "Identity"
  op: "Identity"
  input: "cond/Identity_1"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
library {
  function {
    signature {
      name: "cond_false_34"
      input_arg {
        name: "cond_placeholder"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Const_1"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Const_2"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node_def {
      name: "cond/Const_3"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const_3:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
    }
    node_def {
      name: "cond/Const_4"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_INT32
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_INT32
            tensor_shape {
            }
            int_val: 0
          }
        }
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond/Const_4:output:0"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
  function {
    signature {
      name: "cond_true_33"
      input_arg {
        name: "cond_identity_1_x"
        type: DT_INT32
      }
      output_arg {
        name: "cond_identity"
        type: DT_BOOL
      }
      output_arg {
        name: "cond_identity_1"
        type: DT_INT32
      }
    }
    node_def {
      name: "cond/Const"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_BOOL
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_BOOL
            tensor_shape {
            }
            bool_val: true
          }
        }
      }
    }
    node_def {
      name: "cond/Identity"
      op: "Identity"
      input: "cond/Const:output:0"
      attr {
        key: "T"
        value {
          type: DT_BOOL
        }
      }
    }
    node_def {
      name: "cond/Identity_1"
      op: "Identity"
      input: "cond_identity_1_x"
      attr {
        key: "T"
        value {
          type: DT_INT32
        }
      }
    }
    ret {
      key: "cond_identity"
      value: "cond/Identity:output:0"
    }
    ret {
      key: "cond_identity_1"
      value: "cond/Identity_1:output:0"
    }
    attr {
      key: "_construction_context"
      value {
        s: "kEagerRuntime"
      }
    }
    arg_attr {
      key: 0
      value {
        attr {
          key: "_output_shapes"
          value {
            list {
              shape {
              }
            }
          }
        }
      }
    }
  }
}
versions {
  producer: 898
  min_consumer: 12
}

Hầu hết thời gian, tf.function sẽ hoạt động mà không cần cân nhắc đặc biệt. Tuy nhiên, có một số lưu ý và hướng dẫn tf. Chức năng có thể trợ giúp ở đây, cũng như tài liệu tham khảo đầy đủ về AutoGraph

Đa hình: một Function , nhiều đồ thị

Một tf.Graph chuyên biệt cho một loại đầu vào cụ thể (ví dụ: tenxơ có kiểu dtype cụ thể hoặc các đối tượng có cùng id() ).

Mỗi khi bạn gọi một Function với các kiểu và hình dạng mới trong các đối số của nó, Function sẽ tạo một dtypes tf.Graph cho các đối số mới. Các kiểu và hình dạng của tf.Graph dtypes gọi là chữ ký đầu vào hoặc chỉ là một chữ ký .

Function lưu trữ tf.Graph tương ứng với chữ ký đó trong ConcreteFunction . ConcreteFunction là một trình bao bọc xung quanh một tf.Graph .

@tf.function
def my_relu(x):
  return tf.maximum(0., x)

# `my_relu` creates new graphs as it observes more signatures.
print(my_relu(tf.constant(5.5)))
print(my_relu([1, -1]))
print(my_relu(tf.constant([3., -3.])))
tf.Tensor(5.5, shape=(), dtype=float32)
tf.Tensor([1. 0.], shape=(2,), dtype=float32)
tf.Tensor([3. 0.], shape=(2,), dtype=float32)

Nếu Function đã được gọi với chữ ký đó, thì Function không tạo một tf.Graph mới.

# These two calls do *not* create new graphs.
print(my_relu(tf.constant(-2.5))) # Signature matches `tf.constant(5.5)`.
print(my_relu(tf.constant([-1., 1.]))) # Signature matches `tf.constant([3., -3.])`.
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor([0. 1.], shape=(2,), dtype=float32)

Bởi vì nó được hỗ trợ bởi nhiều đồ thị, một Functionđa hình . Điều đó cho phép nó hỗ trợ nhiều loại đầu vào hơn so với một tf.Graph duy nhất có thể đại diện, cũng như tối ưu hóa từng tf.Graph để có hiệu suất tốt hơn.

# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.
# The `ConcreteFunction` also knows the return type and shape!
print(my_relu.pretty_printed_concrete_signatures())
my_relu(x)
  Args:
    x: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

my_relu(x=[1, -1])
  Returns:
    float32 Tensor, shape=(2,)

my_relu(x)
  Args:
    x: float32 Tensor, shape=(2,)
  Returns:
    float32 Tensor, shape=(2,)

Sử dụng tf.function

Cho đến nay, bạn đã học cách chuyển đổi một hàm Python thành một đồ thị đơn giản bằng cách sử dụng tf.function làm trình trang trí hoặc trình bao bọc. Nhưng trên thực tế, để tf.function hoạt động chính xác có thể rất khó! Trong các phần sau, bạn sẽ tìm hiểu cách làm cho mã của mình hoạt động như mong đợi với tf.function . function.

Thực thi đồ thị so với thực thi háo hức

Mã trong một Function có thể được thực thi cả một cách háo hức và dưới dạng đồ thị. Theo mặc định, Function thực thi mã của nó dưới dạng đồ thị:

@tf.function
def get_MSE(y_true, y_pred):
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)
y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)
print(y_true)
print(y_pred)
tf.Tensor([1 0 4 4 7], shape=(5,), dtype=int32)
tf.Tensor([3 6 3 0 6], shape=(5,), dtype=int32)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=11>

Để xác minh rằng đồ thị của Function của bạn đang thực hiện tính toán giống như hàm Python tương đương của nó, bạn có thể làm cho nó thực thi một cách háo hức với tf.config.run_functions_eagerly(True) . Đây là một công tắc tắt khả năng tạo và chạy đồ thị của Function , thay vào đó thực thi mã một cách bình thường.

tf.config.run_functions_eagerly(True)
get_MSE(y_true, y_pred)
<tf.Tensor: shape=(), dtype=int32, numpy=11>
# Don't forget to set it back when you are done.
tf.config.run_functions_eagerly(False)

Tuy nhiên, Function có thể hoạt động khác nhau dưới biểu đồ và thực thi mong muốn. Hàm print trong Python là một ví dụ về sự khác biệt của hai chế độ này. Hãy kiểm tra điều gì sẽ xảy ra khi bạn chèn một câu lệnh print vào hàm của mình và gọi nó nhiều lần.

@tf.function
def get_MSE(y_true, y_pred):
  print("Calculating MSE!")
  sq_diff = tf.pow(y_true - y_pred, 2)
  return tf.reduce_mean(sq_diff)

Quan sát những gì được in:

error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!

Đầu ra có đáng ngạc nhiên không? get_MSE chỉ được in một lần mặc dù nó được gọi ba lần.

Để giải thích, câu lệnh print được thực thi khi Function chạy mã gốc để tạo đồ thị trong một quá trình được gọi là "truy tìm" . Theo dõi ghi lại các hoạt động TensorFlow vào một biểu đồ và print không được ghi lại trong biểu đồ. Biểu đồ đó sau đó được thực thi cho cả ba lệnh gọi mà không bao giờ chạy lại mã Python .

Để kiểm tra sự tỉnh táo, hãy tắt thực thi đồ thị để so sánh:

# Now, globally set everything to run eagerly to force eager execution.
tf.config.run_functions_eagerly(True)
# Observe what is printed below.
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
Calculating MSE!
Calculating MSE!
Calculating MSE!
tf.config.run_functions_eagerly(False)

print là một tác dụng phụ của Python và có những khác biệt khác mà bạn nên biết khi chuyển đổi một hàm thành một Function . Tìm hiểu thêm trong phần Hạn chế của hướng dẫn Hiệu suất tốt hơn với tf . Chức năng.

Thực hiện không nghiêm ngặt

Thực thi đồ thị chỉ thực hiện các hoạt động cần thiết để tạo ra các hiệu ứng có thể quan sát được, bao gồm:

  • Giá trị trả về của hàm
  • Các tác dụng phụ nổi tiếng đã được ghi nhận như:
    • Các hoạt động đầu vào / đầu ra, như tf.print
    • Các hoạt động gỡ lỗi, chẳng hạn như các chức năng khẳng định trong tf.debugging
    • Đột biến của tf.Variable

Hành vi này thường được gọi là "Thực thi không nghiêm ngặt" và khác với thực thi háo hức, thực hiện các bước thông qua tất cả các hoạt động của chương trình, cần thiết hoặc không.

Đặc biệt, kiểm tra lỗi thời gian chạy không được tính là một hiệu ứng có thể quan sát được. Nếu một thao tác bị bỏ qua vì không cần thiết, nó không thể phát sinh bất kỳ lỗi thời gian chạy nào.

Trong ví dụ sau, thao tác "không cần thiết" tf.gather bị bỏ qua trong quá trình thực thi đồ thị, do đó, lỗi thời gian chạy không hợp InvalidArgumentError không xuất hiện như khi thực thi háo hức. Đừng dựa vào một lỗi được đưa ra trong khi thực hiện một biểu đồ.

def unused_return_eager(x):
  # Get index 1 will fail when `len(x) == 1`
  tf.gather(x, [1]) # unused 
  return x

try:
  print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
  # All operations are run during eager execution so an error is raised.
  print(f'{type(e).__name__}: {e}')
tf.Tensor([0.], shape=(1,), dtype=float32)
@tf.function
def unused_return_graph(x):
  tf.gather(x, [1]) # unused
  return x

# Only needed operations are run during graph exection. The error is not raised.
print(unused_return_graph(tf.constant([0.0])))
tf.Tensor([0.], shape=(1,), dtype=float32)

các phương pháp hay nhất tf.function

Có thể mất một thời gian để làm quen với hoạt động của Function . Để bắt đầu nhanh chóng, người dùng lần đầu tiên nên thử trang trí các chức năng đồ chơi với @tf.function để có kinh nghiệm từ việc thực hiện đồ chơi đến háo hức.

Thiết kế cho tf.function có thể là cách tốt nhất để bạn viết các chương trình TensorFlow tương thích với đồ thị. Dưới đây là một số mẹo:

  • Chuyển đổi giữa việc thực thi đồ thị và háo hức sớm và thường xuyên với tf.config.run_functions_eagerly để xác định nếu / khi hai chế độ khác nhau.
  • Tạo tf.Variable bên ngoài hàm Python và sửa đổi chúng ở bên trong. Tương tự đối với các đối tượng sử dụng tf.Variable , như keras.layers , keras.Model s và tf.optimizers .
  • Tránh viết các hàm phụ thuộc vào các biến bên ngoài Python , ngoại trừ các đối tượng tf.Variable s và Keras.
  • Thích viết các hàm lấy tensor và các loại TensorFlow khác làm đầu vào. Bạn có thể vượt qua các loại đối tượng khác nhưng hãy cẩn thận !
  • Bao gồm nhiều tính toán nhất có thể trong một tf.function để tối đa hóa mức tăng hiệu suất. Ví dụ, trang trí toàn bộ bước đào tạo hoặc toàn bộ vòng đào tạo.

Thấy tốc độ tăng

tf.function thường cải thiện hiệu suất mã của bạn, nhưng tốc độ tăng tốc phụ thuộc vào loại tính toán bạn chạy. Các phép tính nhỏ có thể bị chi phối bởi chi phí gọi một đồ thị. Bạn có thể đo lường sự khác biệt về hiệu suất như sau:

x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)

def power(x, y):
  result = tf.eye(10, dtype=tf.dtypes.int32)
  for _ in range(y):
    result = tf.matmul(x, result)
  return result
print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000))
Eager execution: 2.5637862179974036
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000))
Graph execution: 0.6832536700021592

tf.function thường được sử dụng để tăng tốc các vòng huấn luyện và bạn có thể tìm hiểu thêm về nó trong Viết một vòng huấn luyện từ đầu với Keras.

Hiệu suất và sự đánh đổi

Đồ thị có thể tăng tốc mã của bạn, nhưng quá trình tạo chúng có một số chi phí. Đối với một số hàm, việc tạo đồ thị mất nhiều thời gian hơn so với việc thực hiện đồ thị. Khoản đầu tư này thường nhanh chóng được hoàn vốn với việc tăng hiệu suất của các lần thực hiện tiếp theo, nhưng điều quan trọng cần lưu ý là các bước đầu tiên của bất kỳ quá trình đào tạo mô hình lớn nào có thể chậm hơn do quá trình theo dõi.

Bất kể mô hình của bạn lớn đến mức nào, bạn cũng muốn tránh phải theo dõi thường xuyên. Hướng dẫn tf.function thảo luận về cách đặt thông số kỹ thuật đầu vào và sử dụng các đối số tensor để tránh chạy lại. Nếu bạn nhận thấy mình đang có hiệu suất kém bất thường, bạn nên kiểm tra xem bạn có vô tình thử lại hay không.

Khi nào thì một Function theo dõi?

Để biết khi nào Function của bạn đang theo dõi, hãy thêm một câu lệnh print vào mã của nó. Theo nguyên tắc chung, Function sẽ thực hiện câu lệnh print mỗi khi nó theo dõi.

@tf.function
def a_function_with_python_side_effect(x):
  print("Tracing!") # An eager-only side effect.
  return x * x + tf.constant(2)

# This is traced the first time.
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect.
print(a_function_with_python_side_effect(tf.constant(3)))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)

Các đối số Python mới luôn kích hoạt việc tạo một biểu đồ mới, do đó cần phải theo dõi thêm.

Bước tiếp theo

Bạn có thể tìm hiểu thêm về tf.function trên trang tham chiếu API và làm theo hướng dẫn Hiệu suất tốt hơn với tf.function . Chức năng.