Bergabunglah dengan TensorFlow di Google I/O, 11-12 Mei Daftar sekarang

Performa lebih baik dengan tf.function

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHubUnduh buku catatan

Di TensorFlow 2, eksekusi bersemangat diaktifkan secara default. Antarmuka pengguna intuitif dan fleksibel (menjalankan operasi satu kali jauh lebih mudah dan lebih cepat), tetapi ini dapat mengorbankan kinerja dan kemampuan penerapan.

Anda dapat menggunakan tf.function untuk membuat grafik dari program Anda. Ini adalah alat transformasi yang membuat grafik aliran data Python-independen dari kode Python Anda. Ini akan membantu Anda membuat model yang berkinerja dan portabel, dan itu diperlukan untuk menggunakan SavedModel .

Panduan ini akan membantu Anda mengkonseptualisasikan bagaimana tf.function bekerja di bawah tenda, sehingga Anda dapat menggunakannya secara efektif.

Takeaways utama dan rekomendasi adalah:

  • Debug dalam mode bersemangat, lalu hiasi dengan @tf.function .
  • Jangan mengandalkan efek samping Python seperti mutasi objek atau penambahan daftar.
  • tf.function berfungsi paling baik dengan operasi TensorFlow; Panggilan NumPy dan Python dikonversi ke konstanta.

Mempersiapkan

import tensorflow as tf

Tentukan fungsi pembantu untuk menunjukkan jenis kesalahan yang mungkin Anda temui:

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

Dasar-dasar

Penggunaan

Function yang Anda tentukan (misalnya dengan menerapkan dekorator @tf.function ) sama seperti operasi inti TensorFlow: Anda dapat menjalankannya dengan penuh semangat; Anda dapat menghitung gradien; dan seterusnya.

@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

Anda dapat menggunakan Function s di dalam Function s lainnya.

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Function s bisa lebih cepat daripada kode bersemangat, terutama untuk grafik dengan banyak operasi kecil. Tetapi untuk grafik dengan beberapa operasi mahal (seperti konvolusi), Anda mungkin tidak melihat banyak percepatan.

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735
Function conv: 0.005791576000774512
Note how there's not much difference in performance for convolutions

Pelacakan

Bagian ini memaparkan bagaimana Function bekerja di bawah tenda, termasuk detail implementasi yang mungkin berubah di masa mendatang . Namun, begitu Anda memahami mengapa dan kapan penelusuran terjadi, akan jauh lebih mudah untuk menggunakan tf.function secara efektif!

Apa itu "menelusuri"?

Function menjalankan program Anda dalam Grafik TensorFlow . Namun, tf.Graph tidak dapat mewakili semua hal yang akan Anda tulis dalam program TensorFlow yang bersemangat. Misalnya, Python mendukung polimorfisme, tetapi tf.Graph membutuhkan inputnya untuk memiliki tipe dan dimensi data yang ditentukan. Atau Anda dapat melakukan tugas sampingan seperti membaca argumen baris perintah, memunculkan kesalahan, atau bekerja dengan objek Python yang lebih kompleks; tidak satu pun dari hal-hal ini dapat berjalan di tf.Graph .

Function menjembatani kesenjangan ini dengan memisahkan kode Anda dalam dua tahap:

1) Pada tahap pertama, disebut sebagai " tracing ", Function membuat tf.Graph baru. Kode Python berjalan normal, tetapi semua operasi TensorFlow (seperti menambahkan dua Tensor) ditangguhkan : mereka ditangkap oleh tf.Graph dan tidak dijalankan.

2) Pada tahap kedua, tf.Graph yang berisi semua yang ditangguhkan pada tahap pertama dijalankan. Tahap ini jauh lebih cepat daripada tahap tracing.

Bergantung pada inputnya, Function tidak akan selalu menjalankan tahap pertama saat dipanggil. Lihat "Aturan penelusuran" di bawah untuk mendapatkan pemahaman yang lebih baik tentang bagaimana hal itu membuat penentuan itu. Melewati tahap pertama dan hanya menjalankan tahap kedua akan memberi Anda kinerja tinggi TensorFlow.

Ketika Function memutuskan untuk melacak, tahap penelusuran segera diikuti oleh tahap kedua, jadi memanggil Function akan membuat dan menjalankan tf.Graph . Nanti Anda akan melihat bagaimana Anda hanya dapat menjalankan tahap penelusuran dengan get_concrete_function .

Saat Anda meneruskan argumen dari berbagai jenis ke dalam Function , kedua tahapan dijalankan:

@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)

Perhatikan bahwa jika Anda berulang kali memanggil Function dengan tipe argumen yang sama, TensorFlow akan melewati tahap penelusuran dan menggunakan kembali grafik yang dilacak sebelumnya, karena grafik yang dihasilkan akan identik.

# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)

Anda dapat menggunakan pretty_printed_concrete_signatures() untuk melihat semua jejak yang tersedia:

print(double.pretty_printed_concrete_signatures())
double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Sejauh ini, Anda telah melihat bahwa tf.function membuat lapisan pengiriman dinamis yang di-cache di atas logika pelacakan grafik TensorFlow. Untuk lebih spesifik tentang terminologi:

  • tf.Graph adalah representasi mentah, agnostik bahasa, portabel dari komputasi TensorFlow.
  • Sebuah ConcreteFunction membungkus tf.Graph .
  • Sebuah Function mengelola cache dari ConcreteFunction s dan memilih yang tepat untuk masukan Anda.
  • tf.function membungkus fungsi Python, mengembalikan objek Function .
  • Tracing membuat tf.Graph dan membungkusnya dalam ConcreteFunction , juga dikenal sebagai trace.

Aturan pelacakan

Function menentukan apakah akan menggunakan kembali ConcreteFunction yang dilacak dengan menghitung kunci cache dari argumen dan kwargs input. Kunci cache adalah kunci yang mengidentifikasi ConcreteFunction berdasarkan argumen input dan kwargs dari pemanggilan Function , menurut aturan berikut (yang dapat berubah):

  • Kunci yang dihasilkan untuk tf.Tensor adalah bentuk dan tipenya.
  • Kunci yang dihasilkan untuk tf.Variable adalah id variabel unik.
  • Kunci yang dihasilkan untuk primitif Python (seperti int , float , str ) adalah nilainya.
  • Kunci yang dihasilkan untuk nested dict s, list s, tuple s, namedtuple s, dan attr s adalah tuple yang diratakan dari leaf-keys (lihat nest.flatten ). (Sebagai hasil dari perataan ini, memanggil fungsi konkret dengan struktur bersarang yang berbeda dari yang digunakan selama penelusuran akan menghasilkan TypeError).
  • Untuk semua jenis Python lainnya, kuncinya unik untuk objek. Dengan cara ini fungsi atau metode dilacak secara independen untuk setiap instance yang dipanggil.

Mengontrol penelusuran kembali

Penelusuran ulang, yaitu saat Function Anda membuat lebih dari satu jejak, membantu memastikan bahwa TensorFlow menghasilkan grafik yang benar untuk setiap rangkaian input. Namun, penelusuran adalah operasi yang mahal! Jika Function Anda menelusuri kembali grafik baru untuk setiap panggilan, Anda akan menemukan bahwa kode Anda dieksekusi lebih lambat daripada jika Anda tidak menggunakan tf.function .

Untuk mengontrol perilaku penelusuran, Anda dapat menggunakan teknik berikut:

  • Tentukan input_signature di tf.function untuk membatasi pelacakan.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
  • Tentukan dimensi [None] di tf.TensorSpec untuk memungkinkan fleksibilitas dalam penggunaan kembali pelacakan.

    Karena TensorFlow cocok dengan tensor berdasarkan bentuknya, menggunakan dimensi None sebagai karakter pengganti akan memungkinkan Function s menggunakan kembali jejak untuk input berukuran bervariasi. Input dengan ukuran yang bervariasi dapat terjadi jika Anda memiliki urutan dengan panjang yang berbeda, atau gambar dengan ukuran berbeda untuk setiap batch (Lihat tutorial Transformer dan Deep Dream misalnya).

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x

# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
  • Keluarkan argumen Python ke Tensor untuk mengurangi penelusuran ulang.

    Seringkali, argumen Python digunakan untuk mengontrol hiperparameter dan konstruksi grafik - misalnya, num_layers=10 atau training=True atau nonlinearity='relu' . Jadi, jika argumen Python berubah, masuk akal jika Anda harus menelusuri kembali grafik.

    Namun, ada kemungkinan bahwa argumen Python tidak digunakan untuk mengontrol konstruksi grafik. Dalam kasus ini, perubahan nilai Python dapat memicu penelusuran yang tidak perlu. Ambil, misalnya, loop pelatihan ini, yang AutoGraph akan dibuka secara dinamis. Meskipun banyak jejak, grafik yang dihasilkan sebenarnya identik, jadi penelusuran ulang tidak diperlukan.

def train_one_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()

print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20

Jika Anda perlu memaksa penelusuran ulang, buat Function baru. Objek Function terpisah dijamin tidak akan berbagi jejak.

def f():
  print('Tracing!')
  tf.print('Executing')

tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing

Mendapatkan fungsi konkrit

Setiap kali suatu fungsi ditelusuri, fungsi konkret baru dibuat. Anda bisa langsung mendapatkan fungsi konkret, dengan menggunakan get_concrete_function .

print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)

Mencetak ConcreteFunction menampilkan ringkasan argumen inputnya (dengan tipe) dan tipe outputnya.

print(double_strings)
ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()

Anda juga dapat langsung mengambil tanda tangan fungsi konkret.

print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)

Menggunakan jejak beton dengan tipe yang tidak kompatibel akan menimbulkan kesalahan

with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]

Anda mungkin memperhatikan bahwa argumen Python diberikan perlakuan khusus dalam tanda tangan input fungsi konkret. Sebelum TensorFlow 2.3, argumen Python dihapus begitu saja dari tanda tangan fungsi konkret. Dimulai dengan TensorFlow 2.3, argumen Python tetap ada di tanda tangan, tetapi dibatasi untuk mengambil nilai yang ditetapkan selama penelusuran.

@tf.function
def pow(a, b):
  return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl
    cancellation_manager)
  File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.

Mendapatkan grafik

Setiap fungsi konkret adalah pembungkus yang dapat dipanggil di sekitar tf.Graph . Meskipun mengambil objek tf.Graph yang sebenarnya bukanlah sesuatu yang biasanya perlu Anda lakukan, Anda dapat memperolehnya dengan mudah dari fungsi konkret apa pun.

graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity

Men-debug

Secara umum, kode debug lebih mudah dalam mode bersemangat daripada di dalam tf.function . Anda harus memastikan bahwa kode Anda dijalankan tanpa kesalahan dalam mode bersemangat sebelum mendekorasi dengan tf.function . Untuk membantu proses debug, Anda dapat memanggil tf.config.run_functions_eagerly(True) untuk menonaktifkan dan mengaktifkan kembali tf.function secara global.

Saat melacak masalah yang hanya muncul dalam tf.function , berikut adalah beberapa tip:

  • Panggilan print Python lama biasa hanya dijalankan selama pelacakan, membantu Anda melacak ketika fungsi Anda dilacak (kembali).
  • panggilan tf.print akan dijalankan setiap saat, dan dapat membantu Anda melacak nilai perantara selama eksekusi.
  • tf.debugging.enable_check_numerics adalah cara mudah untuk melacak di mana NaN dan Inf dibuat.
  • pdb ( debug Python ) dapat membantu Anda memahami apa yang terjadi selama penelusuran. (Peringatan: pdb akan memasukkan Anda ke dalam kode sumber yang diubah AutoGraph.)

Transformasi AutoGraph

AutoGraph adalah library yang aktif secara default di tf.function , dan mengubah subset kode bersemangat Python menjadi operasi TensorFlow yang kompatibel dengan grafik. Ini termasuk aliran kontrol seperti if , for , while .

Operasi TensorFlow seperti tf.cond dan tf.while_loop terus bekerja, tetapi aliran kontrol seringkali lebih mudah untuk ditulis dan dipahami jika ditulis dengan Python.

# A simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753]
[0.582645297 0.613145649 0.619306684 0.319202513 0.182036072]
[0.524585426 0.546337605 0.550645113 0.308785647 0.18005164]
[0.481231302 0.497770309 0.501003504 0.299331933 0.178130865]
[0.447229207 0.460361809 0.462906033 0.290701121 0.176270396]
[0.419618756 0.430379033 0.432449728 0.282779962 0.174467146]
[0.396609187 0.405638 0.407366514 0.275476 0.172718227]
[0.377043903 0.384762734 0.386234313 0.268712848 0.17102097]
[0.360137492 0.366836458 0.368109286 0.262426734 0.169372901]
[0.345335096 0.351221472 0.352336824 0.256563932 0.167771652]
[0.332231969 0.337458342 0.338446289 0.251078814 0.166215062]
[0.320524871 0.325206399 0.326089561 0.24593246 0.164701089]
[0.309981436 0.314206958 0.31500268 0.241091311 0.163227797]
[0.300420195 0.304259449 0.304981351 0.236526251 0.161793426]
[0.291697085 0.295205742 0.295864582 0.232211992 0.160396278]
[0.283696055 0.286919087 0.287523568 0.228126258 0.159034774]
[0.276322395 0.279296666 0.27985391 0.224249557 0.157707423]
[0.269497961 0.272254 0.272769839 0.220564634 0.15641281]
[0.263157606 0.265720904 0.266200244 0.21705614 0.155149609]
[0.257246554 0.259638608 0.260085613 0.213710397 0.153916568]
[0.251718313 0.25395745 0.254375577 0.210515186 0.152712509]
[0.246533215 0.248635098 0.249027327 0.207459539 0.151536316]
[0.241657034 0.243635193 0.244004101 0.204533577 0.15038693]
[0.237060249 0.238926381 0.239274174 0.201728329 0.149263337]
[0.232717097 0.234481394 0.234810054 0.199035719 0.148164615]
[0.228605017 0.230276451 0.230587661 0.196448416 0.147089839]
[0.224704206 0.226290658 0.22658591 0.193959698 0.14603813]
[0.220997125 0.222505584 0.222786173 0.191563457 0.145008713]
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077],
      dtype=float32)>

Jika Anda penasaran, Anda dapat memeriksa kode yang dihasilkan tanda tangan.

print(tf.autograph.to_code(f.python_function))
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)

bersyarat

AutoGraph akan mengonversi beberapa pernyataan if <condition> menjadi panggilan tf.cond yang setara. Substitusi ini dilakukan jika <condition> adalah Tensor. Jika tidak, pernyataan if dieksekusi sebagai kondisional Python.

Kondisional Python dijalankan selama penelusuran, jadi tepat satu cabang kondisional akan ditambahkan ke grafik. Tanpa AutoGraph, grafik yang dilacak ini tidak akan dapat mengambil cabang alternatif jika ada aliran kontrol yang bergantung pada data.

tf.cond melacak dan menambahkan kedua cabang kondisional ke grafik, secara dinamis memilih cabang pada waktu eksekusi. Menelusuri dapat memiliki efek samping yang tidak diinginkan; lihat efek tracing AutoGraph untuk informasi lebih lanjut.

@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz

Lihat dokumentasi referensi untuk pembatasan tambahan pada pernyataan if yang dikonversi-Otomatis.

loop

AutoGraph akan mengonversi beberapa pernyataan for dan while menjadi operasi perulangan TensorFlow yang setara, seperti tf.while_loop . Jika tidak dikonversi, loop for atau while dieksekusi sebagai loop Python.

Penggantian ini dilakukan dalam situasi berikut:

  • for x in y : jika y adalah Tensor, konversikan ke tf.while_loop . Dalam kasus khusus di mana y adalah tf.data.Dataset , kombinasi operasi tf.data.Dataset dihasilkan.
  • while <condition> : jika <condition> adalah Tensor, konversikan ke tf.while_loop .

Loop Python dijalankan selama penelusuran, menambahkan operasi tambahan ke tf.Graph untuk setiap iterasi loop.

Loop TensorFlow melacak isi loop, dan secara dinamis memilih berapa banyak iterasi yang akan dijalankan pada waktu eksekusi. Badan loop hanya muncul sekali dalam tf.Graph yang dihasilkan.

Lihat dokumentasi referensi untuk pembatasan tambahan pada pernyataan for dan while yang dikonversi-Otomatis.

Mengulangi data Python

Perangkap umum adalah mengulang data Python/NumPy dalam tf.function . Loop ini akan dijalankan selama proses penelusuran, menambahkan salinan model Anda ke tf.Graph untuk setiap iterasi loop.

Jika Anda ingin membungkus seluruh loop pelatihan dalam tf.function , cara teraman untuk melakukannya adalah dengan membungkus data Anda sebagai tf.data.Dataset sehingga AutoGraph akan membuka gulungan pelatihan secara dinamis.

def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph

Saat membungkus data Python/NumPy dalam Dataset, perhatikan tf.data.Dataset.from_generator versus tf.data.Dataset.from_tensors . Yang pertama akan menyimpan data dalam Python dan mengambilnya melalui tf.py_function yang dapat memiliki implikasi kinerja, sedangkan yang kedua akan menggabungkan salinan data sebagai satu simpul tf.constant() besar dalam grafik, yang dapat memiliki implikasi memori.

Membaca data dari file melalui TFRecordDataset , CsvDataset , dll. adalah cara paling efektif untuk menggunakan data, karena TensorFlow sendiri dapat mengelola pemuatan asinkron dan pengambilan data sebelumnya, tanpa harus melibatkan Python. Untuk mempelajari lebih lanjut, lihat tf.data : Membuat panduan pipeline input TensorFlow .

Mengumpulkan nilai dalam satu lingkaran

Pola yang umum adalah mengakumulasi nilai antara dari sebuah loop. Biasanya, ini dilakukan dengan menambahkan ke daftar Python atau menambahkan entri ke kamus Python. Namun, karena ini adalah efek samping Python, mereka tidak akan berfungsi seperti yang diharapkan dalam loop yang dibuka secara dinamis. Gunakan tf.TensorArray untuk mengumpulkan hasil dari loop yang dibuka secara dinamis.

batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216],
        [0.44997275, 1.9107027 , 1.0716251 , 0.717237  ],
        [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]],

       [[0.04946005, 0.69127274, 0.56848884, 0.22406638],
        [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ],
        [0.9178308 , 1.320889  , 0.989761  , 2.0120025 ]]], dtype=float32)>

Keterbatasan

TensorFlow Function memiliki beberapa batasan berdasarkan desain yang harus Anda perhatikan saat mengonversi fungsi Python ke Function .

Menjalankan efek samping Python

Efek samping, seperti mencetak, menambahkan daftar, dan mengubah global, dapat berperilaku tidak terduga di dalam Function , terkadang dijalankan dua kali atau tidak semuanya. Itu hanya terjadi saat pertama kali Anda memanggil Function dengan satu set input. Setelah itu, tf.Graph yang dilacak dieksekusi ulang, tanpa mengeksekusi kode Python.

Aturan umum adalah untuk menghindari mengandalkan efek samping Python dalam logika Anda dan hanya menggunakannya untuk men-debug jejak Anda. Jika tidak, API TensorFlow seperti tf.data , tf.print , tf.summary , tf.Variable.assign , dan tf.TensorArray adalah cara terbaik untuk memastikan kode Anda akan dieksekusi oleh runtime TensorFlow dengan setiap panggilan.

@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2

Jika Anda ingin mengeksekusi kode Python selama setiap pemanggilan Function , tf.py_function adalah pintu keluar. Kelemahan dari tf.py_function adalah tidak portabel atau sangat berkinerja, tidak dapat disimpan dengan SavedModel, dan tidak berfungsi dengan baik dalam pengaturan terdistribusi (multi-GPU, TPU). Selain itu, karena tf.py_function harus disambungkan ke dalam grafik, tf.py_function mentransmisikan semua input/output ke tensor.

Mengubah variabel global dan bebas Python

Mengubah variabel global dan bebas Python dianggap sebagai efek samping Python, jadi itu hanya terjadi selama penelusuran.

external_list = []

@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect

Terkadang perilaku yang tidak terduga sangat sulit untuk diperhatikan. Dalam contoh di bawah ini, counter dimaksudkan untuk melindungi kenaikan suatu variabel. Namun karena ini adalah bilangan bulat python dan bukan objek TensorFlow, nilainya ditangkap selama pelacakan pertama. Ketika tf.function digunakan, assign_add akan direkam tanpa syarat di grafik yang mendasarinya. Oleh karena itu v akan bertambah 1, setiap kali fungsi tf.function dipanggil. Masalah ini umum terjadi di antara pengguna yang mencoba memigrasikan kode Tensorflow mode Grpah ke Tensorflow 2 menggunakan dekorator tf.function , ketika efek samping python ( counter dalam contoh) digunakan untuk menentukan operasi apa yang akan dijalankan ( assign_add dalam contoh ). Biasanya, pengguna menyadari hal ini hanya setelah melihat hasil numerik yang mencurigakan, atau kinerja yang jauh lebih rendah dari yang diharapkan (misalnya jika operasi yang dijaga sangat mahal).

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 2, 3
1
2
3

Solusi untuk mencapai perilaku yang diharapkan adalah menggunakan tf.init_scope untuk mengangkat operasi di luar grafik fungsi. Ini memastikan bahwa kenaikan variabel hanya dilakukan satu kali selama waktu penelusuran. Perlu dicatat init_scope memiliki efek samping lain termasuk aliran kontrol yang dibersihkan dan pita gradien. Terkadang penggunaan init_scope bisa menjadi terlalu rumit untuk dikelola secara realistis.

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy()) # prints 1, 1, 1
1
1
1

Singkatnya, sebagai aturan praktis, Anda harus menghindari mutasi objek python seperti bilangan bulat atau wadah seperti daftar yang berada di luar Function . Sebagai gantinya, gunakan argumen dan objek TF. Misalnya, bagian "Mengumpulkan nilai dalam satu lingkaran" memiliki satu contoh bagaimana operasi seperti daftar dapat diimplementasikan.

Anda dapat, dalam beberapa kasus, menangkap dan memanipulasi status jika itu adalah tf.Variable . Ini adalah bagaimana bobot model Keras diperbarui dengan panggilan berulang ke ConcreteFunction yang sama.

Menggunakan iterator dan generator Python

Banyak fitur Python, seperti generator dan iterator, mengandalkan runtime Python untuk melacak status. Secara umum, sementara konstruksi ini bekerja seperti yang diharapkan dalam mode bersemangat, mereka adalah contoh efek samping Python dan karena itu hanya terjadi selama penelusuran.

@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1

Sama seperti bagaimana TensorFlow memiliki tf.TensorArray khusus untuk konstruksi daftar, TensorFlow memiliki tf.data.Iterator khusus untuk konstruksi iterasi. Lihat bagian tentang transformasi AutoGraph untuk gambaran umum. Selain itu, tf.data API dapat membantu mengimplementasikan pola generator:

@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3

Semua output dari tf.function harus mengembalikan nilai

Dengan pengecualian tf.Variable s, tf.function harus mengembalikan semua outputnya. Mencoba mengakses tensor apa pun secara langsung dari suatu fungsi tanpa melalui nilai pengembalian menyebabkan "kebocoran".

Misalnya, fungsi di bawah ini "membocorkan" tensor a melalui Python global x :

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
3
'Tensor' object has no attribute 'numpy'

Ini benar bahkan jika nilai yang bocor juga dikembalikan:

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor

correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)

@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b

with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'Tensor' object has no attribute 'numpy'
Caught expected exception 
  <class 'TypeError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/566849597.py", line 21, in <module>
    captures_leaked_tensor(tf.constant(2))
TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
>>>  File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main
>>>  File /usr/lib/python3.7/runpy.py, line 85, in _run_code
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever
>>>  File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once
>>>  File /usr/lib/python3.7/asyncio/events.py, line 88, in _run
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code
>>>  File /tmp/ipykernel_26244/566849597.py, line 7, in <module>
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler
>>>  File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal
>>>  File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__

Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

Biasanya, kebocoran seperti ini terjadi saat Anda menggunakan pernyataan atau struktur data Python. Selain membocorkan tensor yang tidak dapat diakses, pernyataan seperti itu juga kemungkinan salah karena dianggap sebagai efek samping Python, dan tidak dijamin untuk dijalankan pada setiap pemanggilan fungsi.

Cara umum untuk membocorkan tensor lokal juga termasuk mengubah koleksi Python eksternal, atau objek:

class MyClass:

  def __init__(self):
    self.field = None

external_list = []
external_object = MyClass()

def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor

Fungsi tf.rekursif tidak didukung

Function Rekursif s tidak didukung dan dapat menyebabkan loop tak terbatas. Sebagai contoh,

@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1

with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
Caught expected exception 
  <class 'Exception'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module>
    recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
tensorflow.python.autograph.impl.api.StagingError: in user code:

    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn  *
        return recursive_fn(n - 1)
    File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn  *
        if n > 0:
    File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__
        return _abc_instancecheck(cls, instance)

    RecursionError: maximum recursion depth exceeded while calling a Python object

Bahkan jika Function rekursif tampaknya berfungsi, fungsi python akan dilacak beberapa kali dan dapat memiliki implikasi kinerja. Sebagai contoh,

@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1

recursive_fn(5)  # Warning - multiple tracings
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>

Masalah Dikenal

Jika Function Anda tidak mengevaluasi dengan benar, kesalahan dapat dijelaskan oleh masalah yang diketahui ini yang direncanakan untuk diperbaiki di masa mendatang.

Bergantung pada variabel global dan bebas Python

Function membuat ConcreteFunction baru ketika dipanggil dengan nilai baru dari argumen Python. Namun, itu tidak melakukannya untuk penutupan Python, global, atau nonlokal dari Function itu. Jika nilainya berubah di antara panggilan ke Function , Function akan tetap menggunakan nilai yang mereka miliki saat dilacak. Ini berbeda dari cara kerja fungsi Python biasa.

Untuk alasan itu, Anda harus mengikuti gaya pemrograman fungsional yang menggunakan argumen alih-alih menutup nama luar.

@tf.function
def buggy_add():
  return 1 + foo

@tf.function
def recommended_add(foo):
  return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(101, shape=(), dtype=int32)

Cara lain untuk memperbarui nilai global adalah dengan membuatnya menjadi tf.Variable dan menggunakan metode Variable.assign sebagai gantinya.

@tf.function
def variable_add():
  return 1 + foo

foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100!
Variable: tf.Tensor(101, shape=(), dtype=int32)

Bergantung pada objek Python

Rekomendasi untuk meneruskan objek Python sebagai argumen ke tf.function memiliki sejumlah masalah yang diketahui, yang diharapkan dapat diperbaiki di masa mendatang. Secara umum, Anda dapat mengandalkan pelacakan yang konsisten jika Anda menggunakan struktur primitif Python atau yang kompatibel dengan tf.nest sebagai argumen atau meneruskan instance objek yang berbeda ke dalam Function . Namun, Function tidak akan membuat jejak baru saat Anda melewati objek yang sama dan hanya mengubah atributnya .

class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)

Menggunakan Function yang sama untuk mengevaluasi instance model yang diperbarui akan bermasalah karena model yang diperbarui memiliki kunci cache yang sama dengan model aslinya.

Oleh karena itu, Anda disarankan untuk menulis Function agar tidak bergantung pada atribut objek yang dapat diubah atau membuat objek baru.

Jika itu tidak memungkinkan, satu solusinya adalah membuat Function s baru setiap kali Anda memodifikasi objek Anda untuk memaksa penelusuran ulang:

def evaluate(model, x):
  return model.weight * x + model.bias

new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Karena penelusuran ulang bisa mahal , Anda dapat menggunakan tf.Variable s sebagai atribut objek, yang dapat dimutasi (tetapi tidak diubah, hati-hati!) untuk efek serupa tanpa perlu penelusuran ulang.

class BetterModel:

  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)

@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias

better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)

Membuat tf.Variabel

Function hanya mendukung tf.Variable s tunggal yang dibuat sekali pada panggilan pertama, dan digunakan kembali di seluruh panggilan fungsi berikutnya. Cuplikan kode di bawah ini akan membuat tf.Variable baru di setiap panggilan fungsi, yang menghasilkan pengecualian ValueError .

Contoh:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v

with assert_raises(ValueError):
  f(1.0)
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module>
    f(1.0)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3018268426.py", line 3, in f  *
        v = tf.Variable(1.0)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Pola umum yang digunakan untuk mengatasi batasan ini adalah memulai dengan nilai Python None, lalu membuat tf.Variable jika nilainya None:

class Count(tf.Module):
  def __init__(self):
    self.count = None

  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)

c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Menggunakan dengan beberapa pengoptimal Keras

Anda mungkin menemukan ValueError: tf.function only supports singleton tf.Variables created on the first call. saat menggunakan lebih dari satu pengoptimal Keras dengan tf.function . Kesalahan ini terjadi karena pengoptimal membuat tf.Variables secara internal saat mereka menerapkan gradien untuk pertama kalinya.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

@tf.function
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)
Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:
Traceback (most recent call last):
  File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises
    yield
  File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients  **
        self._create_all_weights(var_list)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights
        _ = self.iterations
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight
        aggregation=aggregation)
    File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable
        shape=variable_shape if variable_shape else None)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

Jika Anda perlu mengubah pengoptimal selama pelatihan, solusinya adalah membuat Function baru untuk setiap pengoptimal, dengan memanggil ConcreteFunction secara langsung.

opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
   with tf.GradientTape() as tape:
       L = tf.reduce_sum(tf.square(w*x - y))
   gradients = tape.gradient(L, [w])
   optimizer.apply_gradients(zip(gradients, [w]))

w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y) # `opt1` is not used as a parameter. 
  else:
    train_step_2(w, x, y) # `opt2` is not used as a parameter.

Menggunakan dengan beberapa model Keras

Anda mungkin juga menemukan ValueError: tf.function only supports singleton tf.Variables created on the first call. ketika meneruskan contoh model yang berbeda ke Function yang sama.

Kesalahan ini terjadi karena model Keras (yang bentuk inputnya tidak ditentukan ) dan lapisan Keras membuat tf.Variables s saat pertama kali dipanggil. Anda mungkin mencoba menginisialisasi variabel-variabel tersebut di dalam Function , yang telah dipanggil. Untuk menghindari kesalahan ini, coba panggil model.build(input_shape) untuk menginisialisasi semua bobot sebelum melatih model.

Bacaan lebih lanjut

Untuk mempelajari tentang cara mengekspor dan memuat Function , lihat panduan Model Tersimpan . Untuk mempelajari lebih lanjut tentang pengoptimalan grafik yang dilakukan setelah pelacakan, lihat panduan Grappler . Untuk mempelajari cara mengoptimalkan saluran data dan membuat profil model Anda, lihat panduan Profiler .