tf.function

In TensorFlow 2.0, eager execution is turned on by default. This gives you a very intuitive and flexible user interface (running one-off operations is much easier and faster), but it can come at the expense of performance and deployability.

To get peak performance and to make your model deployable anywhere, we provide tf.function as the tool you can use to make graphs out of your programs.

from __future__ import absolute_import, division, print_function

!pip install -q tensorflow==2.0.0-alpha0
import tensorflow as tf
# A function is like an op

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
<tf.Tensor: id=16, shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>

A tf.function you define is just like a core TensorFlow operation: you can execute it eagerly, you can use it in a graph, it has gradients, etc.

# Functions have gradients

@tf.function
def add(a, b):
  return a + b

v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: id=44, shape=(), dtype=float32, numpy=1.0>
# You can use functions inside functions

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: id=74, shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

Polymorphism

tf.function tries to be as generic as a Python function: you can call Python functions with all sorts of argument types, and Python will usually do something reasonable. tf.function gives you the same kind of polymorphism, even though the underlying TensorFlow graphs it generates are specific to the particular types in their signatures.

You can call a function with arguments of different types to see what is happening.

# Functions are polymorphic

@tf.function
def add(a):
  return a + a

print("add 1", add(1))
print("add 1.1", add(1.1))
print("add string tensor", add(tf.constant("a")))
c = add.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
c(a=tf.constant("a"))  # aa
add 1 tf.Tensor(2, shape=(), dtype=int32)
add 1.1 tf.Tensor(2.2, shape=(), dtype=float32)
add string tensor tf.Tensor(b'aa', shape=(), dtype=string)

<tf.Tensor: id=104, shape=(), dtype=string, numpy=b'aa'>
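
Each input signature you call it with gets its own concrete function, backed by its own graph. A minimal sketch of inspecting one (the float32 signature here is just for illustration):

# Each traced signature has its own concrete function carrying the graph
# that was built for it; you can list the ops in that graph.
c_float = add.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.float32))
print(c_float.graph.get_operations())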
# Functions can be faster than eager code, for graphs with many small ops

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

lstm_cell = tf.keras.layers.LSTMCell(10)

@tf.function
def lstm_fn(input, state):
  return lstm_cell(input, state)

input = tf.zeros([10, 10])
state = [tf.zeros([10, 10])] * 2
# warm up
lstm_cell(input, state); lstm_fn(input, state)
print("eager lstm:", timeit.timeit(lambda: lstm_cell(input, state), number=10))
print("function lstm:", timeit.timeit(lambda: lstm_fn(input, state), number=10))
Eager conv: 0.19106170209124684
Function conv: 0.17481591296382248
Note how there's not much difference in performance for convolutions
eager lstm: 0.0355632440187037
function lstm: 0.0081845159875229

State in tf.function

A very appealing property of functions as a programming model, compared with a general dataflow graph, is that functions can give the runtime more information about the code's intended behavior.

For example, when writing code which has multiple reads and writes to the same variables, a dataflow graph might not naturally encode the originally intended order of operations. In tf.function, however, because we trace the operations from Python code, we know the intended execution order.

This means there's no need to add manual control dependencies; tf.function is smart enough to add the minimal set of control dependencies needed for your code to run correctly.

# Automatic control dependencies

a = tf.Variable(1.0)
b = tf.Variable(2.0)

@tf.function
def f(x, y):
  a.assign(y * b)
  b.assign_add(x * a)
  return a + b

f(1.0, 2.0)  # 10.0
<tf.Tensor: id=1610, shape=(), dtype=float32, numpy=10.0>

Variables

We can use the same idea of leveraging the intended execution order of the code to make variable creation and utilization very easy in tf.function. There is one very important caveat, though: with variables it's possible to write code which behaves differently when called eagerly multiple times than when its output tensor is evaluated multiple times.

Here is a simple example:

@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

f(1.) # Note: BROKEN, will throw exception

If you run this with eager execution, you'll always get "2" as the answer; but if you repeatedly evaluate the Tensor obtained from f(1.) in a graph context you'll get increasing numbers.

So tf.function does not allow you to write code like that.
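
Calling it anyway fails when the function is traced. A minimal sketch of what that looks like (assuming the usual TF 2.x behavior of raising a ValueError about variable creation):

# Hypothetical illustration: the broken f above is rejected at trace time.
try:
  f(1.0)
except ValueError as e:
  print("tf.function rejected ambiguous variable creation:", e)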

# Unambiguous code is fine, though

v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

f(1.0)  # 2.0
f(2.0)  # 4.0
<tf.Tensor: id=1635, shape=(), dtype=float32, numpy=4.0>
# You can also create variables inside a tf.function as long as we can prove
# that those variables are created only the first time the function is executed.

class C: pass
obj = C(); obj.v = None

@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

g(1.0)  # 2.0
g(2.0)  # 4.0
<tf.Tensor: id=1689, shape=(), dtype=float32, numpy=4.0>
# Variable initializers can depend on function arguments and on values of other
# variables. We can figure out the right initialization order using the same
# method we use to generate control dependencies.

state = []
@tf.function
def fn(x):
  if not state:
    state.append(tf.Variable(2.0 * x))
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

fn(tf.constant(1.0))
fn(tf.constant(3.0))
<tf.Tensor: id=1796, shape=(), dtype=float32, numpy=36.0>

Control flow and autograph

While tf.cond and tf.while_loop continue to work with tf.function, we provide a better alternative based on lightweight compilation of your Python code.

The autograph library is fully integrated with tf.function, and it will rewrite conditionals and loops which depend on Tensors to run dynamically in the graph.
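
For example, an if whose predicate is a Tensor is rewritten into a graph conditional. A minimal sketch:

# Because `x < 0` is a Tensor, autograph lowers this `if` into the graph
# (via tf.cond) instead of deciding the branch at trace time.
@tf.function
def abs_value(x):
  if x < 0:
    x = -x
  return x

abs_value(tf.constant(-3.0))  # 3.0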

# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([10]))
[0.31508553 0.721380234 0.129950762 ... 0.433342814 0.690496325 0.801644325]
[0.305056483 0.617763579 0.129224166 ... 0.408111185 0.598300755 0.664955139]
[0.295933127 0.549569 0.128509626 ... 0.386867732 0.535839319 0.581651628]
... (remaining tf.print output truncated) ...

<tf.Tensor: id=1841, shape=(10,), dtype=float32, numpy=
array([0.10085979, 0.10522742, 0.08239178, 0.10422876, 0.10471042,
       0.10409579, 0.08315843, 0.10335443, 0.10513023, 0.10542831],
      dtype=float32)>
# If you're curious you can inspect the code autograph generates.
# It feels like reading assembly language, though.

def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

print(tf.autograph.to_code(f))
from __future__ import print_function

def tf__f(x):
  try:
    with ag__.function_scope('f'):
      do_return = False
      retval_ = None

      def loop_test(x_1):
        with ag__.function_scope('loop_test'):
          return ag__.gt(ag__.converted_call('reduce_sum', tf, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=ag__.Feature.ALL, internal_convert_user_code=True), (x_1,), {}), 1)

      def loop_body(x_1):
        with ag__.function_scope('loop_body'):
          with ag__.utils.control_dependency_on_returns(ag__.converted_call('print', tf, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=ag__.Feature.ALL, internal_convert_user_code=True), (x_1,), {})):
            tf_1, x = ag__.utils.alias_tensors(tf, x_1)
            x = ag__.converted_call('tanh', tf_1, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=ag__.Feature.ALL, internal_convert_user_code=True), (x,), {})
            return x,
      x, = ag__.while_stmt(loop_test, loop_body, (x,), (tf, x, ag__))
      do_return = True
      retval_ = x
      return retval_
  except:
    ag__.rewrite_graph_construction_error(ag_source_map__)



tf__f.autograph_info__ = {}

To control autograph, remember that it only affects the basic control flow constructs in Python (if, for, while, break, etc.) and that it only changes them if the predicates are Tensors.

So in the following example the first loop is statically unrolled while the second loop is dynamically converted:

@tf.function
def f(x):
  for i in range(10):  # Static Python loop: unrolled at trace time, not converted
    do_stuff()
  for i in tf.range(10):  # Depends on a tensor: converted to a dynamic graph loop
    do_stuff()
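
Here is a minimal runnable sketch of the same contrast, using a simple increment in place of the do_stuff placeholder:

# The range(3) loop is unrolled into three separate adds at trace time,
# while the tf.range(3) loop becomes a single dynamic graph loop.
@tf.function
def unrolled_vs_dynamic(x):
  for i in range(3):      # static Python loop
    x = x + 1.0
  for i in tf.range(3):   # tensor-dependent loop
    x = x + 1.0
  return x

unrolled_vs_dynamic(tf.constant(0.0))  # 6.0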

Similarly, to guarantee that prints and asserts happen dynamically, use tf.print and tf.Assert:

@tf.function
def f(x):
  for i in tf.range(10):
    tf.print(i)
    tf.Assert(i < 10, ["a"])
    x += x
  return x

f(10)
0
1
2
3
4
5
6
7
8
9

<tf.Tensor: id=1904, shape=(), dtype=int32, numpy=10240>

Finally, autograph cannot compile arbitrary Python code into TensorFlow graphs. Specifically, any data structures you use dynamically (inside converted control flow) still need to be TensorFlow data structures.

So, for example, the best way to accumulate data in a loop is still to use tf.TensorArray:

@tf.function
def f(x):
  ta = tf.TensorArray(tf.float32, size=10)
  for i in tf.range(10):
    x += x
    ta = ta.write(i, x)
  return ta.stack()

f(10.0)
<tf.Tensor: id=1973, shape=(10,), dtype=float32, numpy=
array([   20.,    40.,    80.,   160.,   320.,   640.,  1280.,  2560.,
        5120., 10240.], dtype=float32)>

Next steps

Now revisit the earlier notebooks and try using tf.function to speed up your code!