Halaman ini diterjemahkan oleh Cloud Translation API.
Switch to English

Performa yang lebih baik dengan fungsi tf.

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Di TensorFlow 2, eksekusi yang cepat diaktifkan secara default. Antarmuka pengguna intuitif dan fleksibel (menjalankan operasi satu kali lebih mudah dan lebih cepat), tetapi ini bisa mengorbankan kinerja dan penerapan.

Anda dapat menggunakan tf.function untuk membuat grafik dari program Anda. Ini adalah alat transformasi yang membuat grafik aliran data Python-independen dari kode Python Anda. Ini akan membantu Anda membuat model berkinerja dan portabel, dan diperlukan untuk menggunakan SavedModel .

Panduan ini akan membantu Anda membuat konsep bagaimana tf.function berfungsi di bawah tenda sehingga Anda dapat menggunakannya secara efektif.

Rekomendasi utama dan rekomendasi adalah:

  • Debug dalam mode eager, lalu hiasi dengan @tf.function .
  • Jangan mengandalkan efek samping Python seperti mutasi objek atau menambahkan daftar.
  • tf.function bekerja paling baik dengan operasi TensorFlow; Panggilan NumPy dan Python dikonversi ke konstanta.

Mempersiapkan

 import tensorflow as tf
 

Tetapkan fungsi pembantu untuk menunjukkan jenis kesalahan yang mungkin Anda temui:

 import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))
 

Dasar-dasar

Pemakaian

Function Anda tetapkan sama seperti operasi inti TensorFlow: Anda dapat menjalankannya dengan penuh semangat; Anda dapat menghitung gradien; dan seterusnya.

 @tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
 
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
 v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
 
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Anda dapat menggunakan Function di dalam Function lainnya.

 @tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
 
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function s bisa lebih cepat daripada kode yang diinginkan, terutama untuk grafik dengan banyak ops kecil. Tetapi untuk grafik dengan beberapa ops mahal (seperti konvolusi), Anda mungkin tidak melihat banyak peningkatan.

 import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

 
Eager conv: 0.0023194860004878137
Function conv: 0.0036776439992536325
Note how there's not much difference in performance for convolutions

Menelusuri

Pengetikan dinamis Python berarti Anda dapat memanggil fungsi dengan berbagai jenis argumen, dan Python dapat melakukan sesuatu yang berbeda di setiap skenario.

Namun, untuk membuat Grafik TensorFlow, dtypes statis dan dimensi bentuk. tf.function menjembatani kesenjangan ini dengan membungkus fungsi Python untuk membuat objek Function . Berdasarkan input yang diberikan, Function memilih grafik yang sesuai untuk input yang diberikan, menelusuri kembali fungsi Python seperlunya. Setelah Anda memahami mengapa dan ketika penelusuran terjadi, jauh lebih mudah untuk menggunakan tf.function secara efektif!

Anda dapat memanggil suatu Function dengan argumen dari tipe yang berbeda untuk melihat perilaku polimorfik ini bekerja.

 @tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

 
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)


Perhatikan bahwa jika Anda berulang kali memanggil suatu Function dengan tipe argumen yang sama, TensorFlow akan menggunakan kembali grafik yang dilacak sebelumnya, karena grafik yang dihasilkan akan sama.

 # This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
 
tf.Tensor(b'bb', shape=(), dtype=string)

(Perubahan berikut tersedia di TensorFlow setiap malam, dan akan tersedia di TensorFlow 2.3.)

Anda dapat menggunakan pretty_printed_concrete_signatures() untuk melihat semua jejak yang tersedia:

 print(double.pretty_printed_concrete_signatures())
 
double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

Sejauh ini, Anda telah melihat bahwa tf.function menciptakan cache, lapisan pengiriman dinamis di atas logika penelusuran grafik TensorFlow. Untuk lebih spesifik tentang terminologi:

  • tf.Graph adalah representasi portabel, bahasa-agnostik, portable dari komputasi Anda.
  • A ConcreteFunction adalah pembungkus yang bersemangat menjalankan sekitar tf.Graph .
  • A Function mengelola cache dari ConcreteFunction s dan memilih yang tepat untuk input Anda.
  • tf.function membungkus fungsi Python, mengembalikan objek Function .

Memperoleh fungsi konkret

Setiap kali fungsi dilacak, fungsi konkret baru dibuat. Anda dapat langsung mendapatkan fungsi konkret, dengan menggunakan get_concrete_function .

 print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))

 
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)

 # You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
 
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'cc', shape=(), dtype=string)

(Perubahan berikut tersedia di TensorFlow setiap malam, dan akan tersedia di TensorFlow 2.3.)

Mencetak ConcreteFunction menampilkan ringkasan argumen inputnya (dengan tipe) dan tipe outputnya.

 print(double_strings)
 
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Anda juga dapat langsung mengambil tanda tangan fungsi konkret.

 print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
 
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

Menggunakan jejak konkret dengan tipe yang tidak kompatibel akan menimbulkan kesalahan

 with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
 
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-15-e4e2860a4364>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_168 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_168]

Anda mungkin memperhatikan bahwa argumen Python diberikan perlakuan khusus dalam tanda tangan input fungsi konkret. Sebelum TensorFlow 2.3, argumen Python hanya dihapus dari tanda tangan fungsi konkret. Dimulai dengan TensorFlow 2.3, argumen Python tetap di tanda tangan, tetapi dibatasi untuk mengambil nilai yang ditetapkan selama pelacakan.

 @tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
 
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>

 assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
 
Caught expected exception 
  <class 'TypeError'>:

Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1669, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1714, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-17-d163f3d206cb>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3

Memperoleh grafik

Setiap fungsi konkret adalah bungkus callable di sekitar tf.Graph . Meskipun mengambil objek tf.Graph sebenarnya bukanlah sesuatu yang biasanya perlu Anda lakukan, Anda dapat memperolehnya dengan mudah dari fungsi konkret apa pun.

 graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')

 
[] -> a
['a', 'a'] -> add
['add'] -> Identity

Debugging

Secara umum, kode debug lebih mudah dalam mode eager daripada di dalam tf.function . Anda harus memastikan bahwa kode Anda mengeksekusi bebas kesalahan dalam mode eager sebelum menghias dengan tf.function . Untuk membantu dalam proses debug, Anda dapat memanggil tf.config.run_functions_eagerly(True) untuk menonaktifkan dan mengaktifkan kembali secara global tf.function .

Saat melacak masalah yang hanya muncul dalam tf.function , berikut adalah beberapa kiat:

  • Panggilan print Python lama biasa hanya dieksekusi selama pelacakan, membantu Anda melacak ketika fungsi Anda (kembali) dilacak.
  • tf.print panggilan akan dieksekusi setiap waktu, dan dapat membantu Anda melacak nilai-nilai perantara selama eksekusi.
  • tf.debugging.enable_check_numerics adalah cara mudah untuk melacak di mana NaNs dan Inf dibuat.
  • pdb dapat membantu Anda memahami apa yang terjadi selama penelusuran. (Peringatan: PDB akan menurunkan Anda ke kode sumber yang diubah AutoGraph.)

Melacak semantik

Aturan kunci cache

Suatu Function menentukan apakah akan menggunakan kembali fungsi beton yang dilacak dengan menghitung kunci cache dari arg dan kwarg input.

  • Kunci yang dihasilkan untuk argumen tf.Tensor adalah bentuk dan jenisnya.
  • Mulai di TensorFlow 2.3, kunci yang dihasilkan untuk argumen tf.Variable adalah id() .
  • Kunci yang dihasilkan untuk primitif Python adalah nilainya. Kunci yang dihasilkan untuk dict bersarang, list s, tuple s, namedtuple s, dan attr s adalah tuple pipih. (Sebagai hasil dari perataan ini, memanggil fungsi beton dengan struktur bersarang yang berbeda dari yang digunakan selama pelacakan akan menghasilkan TypeError).
  • Untuk semua jenis Python lainnya, kunci didasarkan pada objek id() sehingga metode dilacak secara independen untuk setiap instance kelas.

Mengontrol pelacakan

Retracing membantu memastikan bahwa TensorFlow menghasilkan grafik yang benar untuk setiap rangkaian input. Namun, pelacakan adalah operasi yang mahal! Jika Function Anda menelusuri kembali grafik baru untuk setiap panggilan, Anda akan menemukan bahwa kode Anda dieksekusi lebih lambat daripada jika Anda tidak menggunakan tf.function .

Untuk mengontrol perilaku penelusuran, Anda dapat menggunakan teknik berikut:

  • Tentukan input_signature di tf.function untuk membatasi pelacakan.
 @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))

 
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-19-20f544b8adbf>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-19-20f544b8adbf>", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))

  • Tentukan dimensi [Tidak Ada] di tf.TensorSpec untuk memungkinkan fleksibilitas dalam penggunaan ulang jejak.

    Karena TensorFlow mencocokkan tensor berdasarkan bentuknya, menggunakan dimensi None sebagai wildcard akan memungkinkan Function s untuk menggunakan kembali jejak untuk input berukuran bervariasi. Input berukuran variabel dapat terjadi jika Anda memiliki urutan panjang yang berbeda, atau gambar dengan ukuran berbeda untuk setiap batch (Lihat tutorial Transformer dan Deep Dream misalnya).

 @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))

 
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)

  • Keluarkan argumen Python ke Tensor untuk mengurangi penelusuran.

    Seringkali, argumen Python digunakan untuk mengontrol hyperparameters dan konstruksi grafik - misalnya, num_layers=10 atau training=True atau nonlinearity='relu' . Jadi jika argumen Python berubah, masuk akal bahwa Anda harus menelusuri kembali grafik.

    Namun, ada kemungkinan bahwa argumen Python tidak digunakan untuk mengontrol konstruksi grafik. Dalam kasus ini, perubahan nilai Python dapat memicu penelusuran ulang yang tidak perlu. Ambil, misalnya, loop pelatihan ini, yang AutoGraph akan buka secara dinamis. Meskipun ada banyak jejak, grafik yang dihasilkan sebenarnya identik, sehingga penelusuran ulang tidak perlu.

 def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
 
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Jika Anda perlu memaksakan pelacakan, buat sebuah Function baru. Objek Function terpisah dijamin tidak akan berbagi jejak.

 def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
 
Tracing!
Executing
Tracing!
Executing

Efek samping python

Efek samping python seperti mencetak, menambahkan ke daftar, dan mengubah global hanya terjadi pertama kali Anda memanggil suatu Function dengan satu set input. Setelah itu, tf.Graph dilacak tf.Graph kembali, tanpa mengeksekusi kode Python.

Aturan umum adalah untuk hanya menggunakan efek samping Python untuk men-debug jejak Anda. Jika tidak, operasi TensorFlow seperti tf.Variable.assign , tf.print , dan tf.summary adalah cara terbaik untuk memastikan kode Anda akan dilacak dan dieksekusi oleh runtime TensorFlow dengan setiap panggilan.

 @tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)

 
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Banyak fitur Python, seperti generator dan iterator, bergantung pada runtime Python untuk melacak keadaan. Secara umum, sementara konstruksi ini berfungsi seperti yang diharapkan dalam mode bersemangat, banyak hal tak terduga dapat terjadi di dalam suatu Function .

Untuk memberikan satu contoh, memajukan status iterator adalah efek samping Python dan karenanya hanya terjadi selama pelacakan.

 external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

 
Value of external_var: 0
Value of external_var: 0
Value of external_var: 0

Beberapa konstruksi iterasi didukung melalui AutoGraph. Lihat bagian tentang AutoGraph Transformations untuk ikhtisar.

Jika Anda ingin mengeksekusi kode Python selama setiap pemanggilan suatu Function , tf.py_function adalah pintu keluar. Kelemahan dari tf.py_function adalah tidak portabel atau berkinerja tinggi, juga tidak berfungsi dengan baik dalam pengaturan terdistribusi (multi-GPU, TPU). Juga, karena tf.py_function harus ditransfer ke dalam grafik, tf.cips semua input / output ke tensor.

API seperti tf.gather , tf.stack , dan tf.TensorArray dapat membantu Anda menerapkan pola perulangan umum di TensorFlow asli.

 external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
# The list append happens all three times!
assert len(external_list) == 3
# The list contains tf.constant(1), not 1, because py_function casts everything to tensors.
assert external_list[0].numpy() == 1

 
Python side effect
Python side effect
Python side effect

Variabel

Anda mungkin mengalami kesalahan saat membuat tf.Variable baru. tf.Variable dalam suatu fungsi. Kesalahan ini menjaga terhadap perbedaan perilaku pada panggilan berulang: Dalam mode eager, suatu fungsi membuat variabel baru dengan setiap panggilan, tetapi dalam suatu Function , variabel baru tidak dapat dibuat karena melacak penggunaan kembali.

 @tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)
 
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-26-73e410646579>", line 8, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-26-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:702 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.


Anda bisa membuat variabel di dalam suatu Function selama variabel itu hanya dibuat saat pertama kali fungsi dijalankan.

 class Count(tf.Module):
  def __init__(self):
    self.count = None
  
  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
 
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Kesalahan lain yang mungkin Anda temui adalah variabel sampah yang dikumpulkan. Tidak seperti fungsi Python normal, fungsi konkret hanya mempertahankan WeakRefs ke variabel yang ditutup, sehingga Anda harus mempertahankan referensi ke variabel apa pun.

 external_var = tf.Variable(3)
@tf.function
def f(x):
  return x * external_var

traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

del external_var
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
  traced_f(4)
 
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:

Traceback (most recent call last):
  File "<ipython-input-3-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-28-304a18524b57>", line 14, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-28-304a18524b57>:4) ]]
     [[ReadVariableOp/_2]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar4 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar4/N10tensorflow3VarE does not exist.
     [[node ReadVariableOp (defined at <ipython-input-28-304a18524b57>:4) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_f_514]

Function call stack:
f -> f


Transformasi AutoGraph

AutoGraph adalah pustaka yang aktif secara default di tf.function , dan mentransformasikan subset kode Python eager menjadi ops TensorFlow yang kompatibel dengan grafik. Ini termasuk aliran kontrol seperti if , for , while .

Operasi TensorFlow seperti tf.cond dan tf.while_loop terus bekerja, tetapi aliran kontrol seringkali lebih mudah untuk ditulis dan dipahami ketika ditulis dengan Python.

 # Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
 
[0.448926926 0.896036148 0.703306437 0.446930766 0.20440042]
[0.421016544 0.714362323 0.6064623 0.419372857 0.201600626]
[0.397786468 0.613405049 0.541632056 0.396401972 0.198913112]
[0.378053397 0.546519518 0.494222373 0.376866162 0.196330562]
[0.361015767 0.497907132 0.457561225 0.359982818 0.1938463]
[0.346108437 0.460469633 0.428094476 0.3451989 0.191454232]
[0.332919776 0.43046692 0.403727621 0.332110822 0.189148799]
[0.321141869 0.405711472 0.383133948 0.320416152 0.18692489]
[0.310539037 0.384825289 0.365426034 0.309883147 0.184777796]
[0.300927401 0.366890609 0.349984437 0.300330788 0.182703182]
[0.292161077 0.351268977 0.336361736 0.291615278 0.180697069]
[0.284122646 0.337500453 0.324225426 0.283620834 0.178755745]
[0.276716352 0.325244069 0.313322544 0.276252925 0.176875815]
[0.269863278 0.314240903 0.303456694 0.269433528 0.175054088]
[0.263497591 0.304290265 0.294472754 0.263097644 0.17328763]
[0.257564 0.295233846 0.2862463 0.257190555 0.171573699]
[0.25201565 0.286944896 0.278676242 0.25166589 0.169909731]
[0.246812463 0.279320478 0.271679461 0.246483982 0.168293342]
[0.24192 0.272276044 0.265186876 0.241610721 0.166722313]
[0.237308443 0.265741408 0.259140551 0.237016559 0.165194541]
[0.23295185 0.25965777 0.253491491 0.232675791 0.163708091]
[0.228827521 0.253975391 0.248197898 0.228565902 0.162261128]
[0.224915475 0.248651937 0.243223906 0.224667087 0.160851941]
[0.221198082 0.243651047 0.238538548 0.220961839 0.159478888]
[0.217659682 0.238941342 0.23411487 0.217434615 0.158140466]
[0.214286327 0.23449555 0.229929343 0.214071587 0.156835243]
[0.211065561 0.230289876 0.225961298 0.210860386 0.155561864]
[0.207986191 0.226303399 0.222192511 0.207789883 0.154319063]
[0.20503816 0.222517684 0.2186068 0.204850093 0.153105617]

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.20221236, 0.2189164 , 0.21518978, 0.20203198, 0.15192041],
      dtype=float32)>

Jika Anda penasaran, Anda dapat memeriksa kode yang dihasilkan tanda tangan.

 print(tf.autograph.to_code(f.python_function))
 
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)


Persyaratan

AutoGraph akan mengonversi beberapa pernyataan if <condition> menjadi panggilan tf.cond setara. Substitusi ini dibuat jika <condition> adalah Tensor. Jika tidak, pernyataan if dijalankan sebagai kondisi Python.

Conditional Python dieksekusi selama tracing, jadi tepat satu cabang conditional akan ditambahkan ke grafik. Tanpa AutoGraph, grafik yang dilacak ini tidak akan dapat mengambil cabang alternatif jika ada aliran kontrol yang bergantung data.

tf.cond melacak dan menambahkan kedua cabang kondisional ke grafik, memilih cabang secara dinamis pada waktu eksekusi. Menelusuri dapat memiliki efek samping yang tidak diinginkan; lihat Efek penelusuran AutoGraph untuk lebih lanjut.

 @tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
 
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Lihat dokumentasi referensi untuk batasan tambahan pada pernyataan AutoGraph yang dikonversi.

Loop

AutoGraph akan mengonversi beberapa for dan while pernyataan menjadi ops looping TensorFlow yang setara, seperti tf.while_loop . Jika tidak dikonversi, loop for atau while dieksekusi sebagai loop Python.

Substitusi ini dibuat dalam situasi berikut:

  • for x in y : jika y adalah Tensor, konversikan ke tf.while_loop . Dalam kasus khusus di mana y adalah tf.data.Dataset , kombinasi op tf.data.Dataset dihasilkan.
  • while <condition> : jika <condition> adalah Tensor, konversikan ke tf.while_loop .

Loop Python dieksekusi selama tracing, menambahkan ops tambahan ke tf.Graph untuk setiap iterasi loop.

Loop TensorFlow melacak tubuh loop, dan secara dinamis memilih berapa banyak iterasi yang akan dijalankan pada waktu eksekusi. Badan loop hanya muncul sekali di tf.Graph dihasilkan.

Lihat dokumentasi referensi untuk batasan tambahan pada AutoGraph-dikonversi for dan while pernyataan.

Looping melalui data Python

Jebakan yang umum adalah untuk mengulang data Python / Numpy dalam fungsi tf.function . Loop ini akan dieksekusi selama proses penelusuran, menambahkan salinan model Anda ke tf.Graph untuk setiap iterasi loop.

Jika Anda ingin membungkus seluruh loop pelatihan di tf.function , cara paling aman untuk melakukan ini adalah dengan membungkus data Anda sebagai tf.data.Dataset sehingga AutoGraph secara dinamis akan membuka gulungan loop pelatihan.

 def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
 
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph

Saat membungkus data Python / Numpy dalam Dataset, berhati-hatilah dengan tf.data.Dataset.from_generator dibandingkan tf.data.Dataset.from_tensors . Yang pertama akan menyimpan data dalam Python dan mengambilnya melalui tf.py_function yang dapat memiliki implikasi kinerja, sedangkan yang terakhir akan membundel salinan data sebagai satu simpul tf.constant() dalam grafik, yang dapat memiliki implikasi memori.

Membaca data dari file melalui TFRecordDataset / CsvDataset / etc. adalah cara paling efektif untuk mengonsumsi data, karena TensorFlow sendiri dapat mengelola pemuatan dan sinkronisasi data yang asinkron, tanpa harus melibatkan Python. Untuk mempelajari lebih lanjut, lihat panduan tf.data .

Mengumpulkan nilai dalam satu lingkaran

Pola umum adalah untuk mengakumulasikan nilai antara dari satu loop. Biasanya, ini dilakukan dengan menambahkan ke daftar Python atau menambahkan entri ke kamus Python. Namun, karena ini adalah efek samping Python, mereka tidak akan berfungsi seperti yang diharapkan dalam loop yang terbuka secara dinamis. Gunakan tf.TensorArray untuk mengakumulasikan hasil dari loop yang terbuka secara dinamis.

 batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
 
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.2486304 , 0.0612042 , 0.69624186, 0.28587592],
        [1.2193475 , 0.2389338 , 1.5216837 , 0.38649392],
        [1.7640524 , 1.1970762 , 2.3265643 , 0.81419575]],

       [[0.36599267, 0.41830885, 0.73540664, 0.63987565],
        [0.48354673, 1.1808103 , 1.7210082 , 0.8333106 ],
        [0.7138835 , 1.2030114 , 1.8544207 , 1.1647347 ]]], dtype=float32)>

Bacaan lebih lanjut

Untuk mempelajari tentang cara mengekspor dan memuat suatu Function , lihat panduan SavedModel . Untuk mempelajari lebih lanjut tentang optimisasi grafik yang dilakukan setelah penelusuran, lihat panduan Grappler . Untuk mempelajari cara mengoptimalkan saluran data dan profil model Anda, lihat panduan Profiler .