This guide is for users of low-level TensorFlow APIs. If you are using the high-level APIs (tf.keras), there may be little or no action you need to take to make your code fully TensorFlow 2.x compatible:
- Check your optimizer's default learning rate.
- Note that the "name" that metrics are logged to may have changed.
It is still possible to run 1.x code, unmodified (except for contrib), in TensorFlow 2.x:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
However, this does not let you take advantage of many of the improvements made in TensorFlow 2.x. This guide will help you upgrade your code, making it simpler, more performant, and easier to maintain.
Automatic conversion script
The first step, before attempting to implement the changes described in this guide, is to try running the upgrade script.
This will execute an initial pass at upgrading your code to TensorFlow 2.x, but it can't make your code idiomatic to v2. Your code may still make use of tf.compat.v1 endpoints to access placeholders, sessions, collections, and other 1.x-style functionality.
Top-level behavioral changes
If your code works in TensorFlow 2.x using tf.compat.v1.disable_v2_behavior, there are still global behavioral changes you may need to address. The major changes are:
- Eager execution, v1.enable_eager_execution(): Any code that implicitly uses a tf.Graph will fail. Be sure to wrap this code in a with tf.Graph().as_default() context.
- Resource variables, v1.enable_resource_variables(): Some code may depend on non-deterministic behaviors enabled by TensorFlow reference variables. Resource variables are locked while being written to, and so provide more intuitive consistency guarantees.
  - This may change behavior in edge cases.
  - This may create extra copies and can have higher memory usage.
  - This can be disabled by passing use_resource=False to the tf.Variable constructor.
- Tensor shapes, v1.enable_v2_tensorshape(): TensorFlow 2.x simplifies the behavior of tensor shapes. Instead of t.shape[0].value you can say t.shape[0]. These changes should be small, and it makes sense to fix them right away (see the sketch after this list). Refer to the TensorShape section for examples.
- Control flow, v1.enable_control_flow_v2(): The TensorFlow 2.x control flow implementation has been simplified, and so produces different graph representations. Please file bugs for any issues.
Create code for TensorFlow 2.x
This guide will walk through several examples of converting TensorFlow 1.x code to TensorFlow 2.x. These changes will let your code take advantage of performance optimizations and simplified API calls.
In each case, the pattern is:
1. Replace v1.Session.run calls

Every v1.Session.run call should be replaced by a Python function.
- The feed_dict and v1.placeholders become function arguments.
- The fetches become the function's return value.
- During conversion, eager execution allows easy debugging with standard Python tools like pdb.
After that, add a tf.function decorator to make it run efficiently in graph mode. Check out the Autograph guide for more information about how this works.
Note that:

- Unlike v1.Session.run, a tf.function has a fixed return signature and always returns all outputs. If this causes performance problems, create two separate functions.
- There is no need for tf.control_dependencies or similar operations: a tf.function behaves as if it were run in the order written. tf.Variable assignments and tf.asserts, for example, are executed automatically.
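As a minimal sketch of this pattern (the function body and constant are hypothetical, chosen only for illustration), the feed becomes an argument and the fetch becomes the return value:

import tensorflow as tf

# TF 1.x (for comparison):
#   out = sess.run(y, feed_dict={x: 2.0})
# TF 2.x: the placeholder becomes an argument, the fetch becomes the
# return value, and tf.function makes it run as a graph.
@tf.function
def compute(x):
  return tf.square(x)

print(compute(tf.constant(2.0)))  # tf.Tensor(4.0, shape=(), dtype=float32)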
The converting models section contains a working example of this conversion process.
2. Use Python objects to track variables and losses
All name-based variable tracking is strongly discouraged in TensorFlow 2.x. Use Python objects to track variables.
Use tf.Variable instead of v1.get_variable.
Every v1.variable_scope should be converted to a Python object. Typically this will be one of:

- tf.keras.layers.Layer
- tf.keras.Model
- tf.Variable
If you need to aggregate lists of variables (like tf.Graph.get_collection(tf.GraphKeys.VARIABLES)), use the .variables and .trainable_variables attributes of the Layer and Model objects.
These Layer and Model classes implement several other properties that remove the need for global collections. Their .losses property can be a replacement for using the tf.GraphKeys.LOSSES collection.
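For instance, a minimal sketch of reading these attributes (the layer sizes and regularizer strength are arbitrary, chosen only for illustration):

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, kernel_regularizer=tf.keras.regularizers.l2(0.04)),
    tf.keras.layers.Dense(1),
])
model.build(input_shape=(None, 8))

print(len(model.trainable_variables))  # kernels and biases: 4
print(model.losses)                    # the l2 regularization loss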
Refer to the Keras guides for more details.
3. Upgrade your training loops
Use the highest-level API that works for your use case. Prefer tf.keras.Model.fit over building your own training loops.

These high-level functions manage a lot of the low-level details that might be easy to miss if you write your own training loop. For example, they automatically collect the regularization losses and set the training=True argument when calling the model.
4. Upgrade your data input pipelines
Use tf.data datasets for data input. These objects are efficient, expressive, and integrate well with TensorFlow.

They can be passed directly to the tf.keras.Model.fit method.
model.fit(dataset, epochs=5)
They can also be iterated over directly in standard Python:
for example_batch, label_batch in dataset:
break
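A minimal pipeline sketch, assuming some hypothetical in-memory Numpy arrays (the shapes and sizes are illustrative only):

import numpy as np
import tensorflow as tf

images = np.random.rand(100, 28, 28, 1).astype("float32")
labels = np.random.randint(0, 10, size=(100,))

dataset = (tf.data.Dataset.from_tensor_slices((images, labels))
           .shuffle(buffer_size=100)   # shuffle the examples
           .batch(32)                  # collect batches
           .prefetch(tf.data.experimental.AUTOTUNE))  # overlap with training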
5. Migrate off compat.v1 symbols

The tf.compat.v1 module contains the complete TensorFlow 1.x API, with its original semantics.
The TensorFlow 2.x upgrade script will convert symbols to their v2 equivalents if such a conversion is safe, i.e., if it can determine that the behavior of the TensorFlow 2.x version is exactly equivalent (for instance, it will rename v1.arg_max to tf.argmax, since those are the same function).
After the upgrade script is done with a piece of code, it is likely there are many mentions of compat.v1. It is worth going through the code and converting these manually to the v2 equivalent (it should be mentioned in the log if there is one).
Converting models
Low-level variables & operator execution
Examples of low-level API use include:

- Using variable scopes to control reuse.
- Creating variables with v1.get_variable.
- Accessing collections explicitly.
- Accessing collections implicitly with methods like v1.global_variables and v1.losses.get_regularization_loss.
- Using v1.placeholder to set up graph inputs.
- Executing graphs with Session.run.
- Initializing variables manually.
Before converting
Here is what these patterns may look like in code using TensorFlow 1.x.
import tensorflow as tf
import tensorflow.compat.v1 as v1
import tensorflow_datasets as tfds
g = v1.Graph()

with g.as_default():
  in_a = v1.placeholder(dtype=v1.float32, shape=(2))
  in_b = v1.placeholder(dtype=v1.float32, shape=(2))

  def forward(x):
    with v1.variable_scope("matmul", reuse=v1.AUTO_REUSE):
      W = v1.get_variable("W", initializer=v1.ones(shape=(2,2)),
                          regularizer=lambda x: tf.reduce_mean(x**2))
      b = v1.get_variable("b", initializer=v1.zeros(shape=(2)))
      return W * x + b

  out_a = forward(in_a)
  out_b = forward(in_b)
  reg_loss = v1.losses.get_regularization_loss(scope="matmul")

with v1.Session(graph=g) as sess:
  sess.run(v1.global_variables_initializer())
  outs = sess.run([out_a, out_b, reg_loss],
                  feed_dict={in_a: [1, 0], in_b: [0, 1]})
  print(outs[0])
  print()
  print(outs[1])
  print()
  print(outs[2])
[[1. 0.]
 [1. 0.]]

[[0. 1.]
 [0. 1.]]

1.0
After converting
In the converted code:

- The variables are local Python objects.
- The forward function still defines the calculation.
- The Session.run call is replaced with a call to forward.
- The optional tf.function decorator can be added for performance.
- The regularizations are calculated manually, without referring to any global collection.
- There's no usage of sessions or placeholders.
W = tf.Variable(tf.ones(shape=(2,2)), name="W")
b = tf.Variable(tf.zeros(shape=(2)), name="b")
@tf.function
def forward(x):
  return W * x + b
out_a = forward([1,0])
print(out_a)
tf.Tensor(
[[1. 0.]
 [1. 0.]], shape=(2, 2), dtype=float32)
out_b = forward([0,1])
regularizer = tf.keras.regularizers.l2(0.04)
reg_loss=regularizer(W)
Models based on tf.layers
The v1.layers module is used to contain layer-functions that relied on v1.variable_scope to define and reuse variables.
Before converting
def model(x, training, scope='model'):
  with v1.variable_scope(scope, reuse=v1.AUTO_REUSE):
    x = v1.layers.conv2d(x, 32, 3, activation=v1.nn.relu,
          kernel_regularizer=lambda x: 0.004 * tf.reduce_mean(x**2))
    x = v1.layers.max_pooling2d(x, (2, 2), 1)
    x = v1.layers.flatten(x)
    x = v1.layers.dropout(x, 0.1, training=training)
    x = v1.layers.dense(x, 64, activation=v1.nn.relu)
    x = v1.layers.batch_normalization(x, training=training)
    x = v1.layers.dense(x, 10)
    return x
train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))
train_out = model(train_data, training=True)
test_out = model(test_data, training=False)
print(train_out)
print()
print(test_out)
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/convolutional.py:414: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py:2273: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/pooling.py:310: UserWarning: `tf.layers.max_pooling2d` is deprecated and will be removed in a future version. Please use `tf.keras.layers.MaxPooling2D` instead.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:329: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:268: UserWarning: `tf.layers.dropout` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dropout` instead.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:171: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/legacy_tf_layers/normalization.py:308: UserWarning: `tf.layers.batch_normalization` is deprecated and will be removed in a future version. Please use `tf.keras.layers.BatchNormalization` instead. In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.BatchNormalization` documentation).

tf.Tensor([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(1, 10), dtype=float32)

tf.Tensor(
[[ 0.379358   -0.55901194  0.48704922  0.11619566  0.23902717  0.01691487
   0.07227738  0.14556988  0.2459927   0.2501198 ]], shape=(1, 10), dtype=float32)
After converting
- The simple stack of layers fits neatly into tf.keras.Sequential. (For more complex models, check out the custom layers and models and the functional API guides.)
- The model tracks the variables and regularization losses.
- The conversion was one-to-one because there is a direct mapping from v1.layers to tf.keras.layers.
Most arguments stayed the same. But notice the differences:

- The training argument is passed to each layer by the model when it runs.
- The first argument to the original model function (the input x) is gone. This is because object layers separate building the model from calling the model.
Also note that:

- If you are using regularizers or initializers from tf.contrib, these have more argument changes than others.
- The code no longer writes to collections, so functions like v1.losses.get_regularization_loss will no longer return these values, potentially breaking your training loops.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.04),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])
train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))
train_out = model(train_data, training=True)
print(train_out)
tf.Tensor([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(1, 10), dtype=float32)
test_out = model(test_data, training=False)
print(test_out)
tf.Tensor( [[-0.2145557 -0.22979769 -0.14968733 0.01208701 -0.07569927 0.3475932 0.10718458 0.03482988 -0.04309493 -0.10469118]], shape=(1, 10), dtype=float32)
# Here are all the trainable variables
len(model.trainable_variables)
8
# Here is the regularization loss
model.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.08174552>]
Mixed variables & v1.layers
Existing code often mixes lower-level TensorFlow 1.x variables and operations with higher-level v1.layers.
Before converting
def model(x, training, scope='model'):
  with v1.variable_scope(scope, reuse=v1.AUTO_REUSE):
    W = v1.get_variable(
      "W", dtype=v1.float32,
      initializer=v1.ones(shape=x.shape),
      regularizer=lambda x: 0.004 * tf.reduce_mean(x**2),
      trainable=True)
    if training:
      x = x + W
    else:
      x = x + W * 0.5
    x = v1.layers.conv2d(x, 32, 3, activation=tf.nn.relu)
    x = v1.layers.max_pooling2d(x, (2, 2), 1)
    x = v1.layers.flatten(x)
    return x
train_out = model(train_data, training=True)
test_out = model(test_data, training=False)
After converting
To convert this code, follow the pattern of mapping layers to layers as in the previous example.
The general pattern is:

- Collect layer parameters in __init__.
- Build the variables in build.
- Execute the calculations in call, and return the result.

The v1.variable_scope is essentially a layer of its own. So rewrite it as a tf.keras.layers.Layer. Check out the Making new Layers and Models via subclassing guide for details.
# Create a custom layer for part of the model
class CustomLayer(tf.keras.layers.Layer):
  def __init__(self, *args, **kwargs):
    super(CustomLayer, self).__init__(*args, **kwargs)

  def build(self, input_shape):
    self.w = self.add_weight(
        shape=input_shape[1:],
        dtype=tf.float32,
        initializer=tf.keras.initializers.ones(),
        regularizer=tf.keras.regularizers.l2(0.02),
        trainable=True)

  # The call method will sometimes get used in graph mode,
  # where `training` will get turned into a tensor.
  @tf.function
  def call(self, inputs, training=None):
    if training:
      return inputs + self.w
    else:
      return inputs + self.w * 0.5
custom_layer = CustomLayer()
print(custom_layer([1]).numpy())
print(custom_layer([1], training=True).numpy())
[1.5]
[2.]
train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))
# Build the model including the custom layer
model = tf.keras.Sequential([
    CustomLayer(input_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
])
train_out = model(train_data, training=True)
test_out = model(test_data, training=False)
Some things to note:

- Subclassed Keras models and layers need to run in both v1 graphs (no automatic control dependencies) and in eager mode:
  - Wrap the call in a tf.function to get autograph and automatic control dependencies.
- Don't forget to accept a training argument to call:
  - Sometimes it is a tf.Tensor
  - Sometimes it is a Python boolean
- Create model variables in the constructor or Model.build using self.add_weight:
  - In Model.build you have access to the input shape, so you can create weights with matching shape
  - Using tf.keras.layers.Layer.add_weight allows Keras to track variables and regularization losses
- Don't keep tf.Tensors in your objects (see the sketch after this list):
  - They might get created either in a tf.function or in the eager context, and these tensors behave differently
  - Use tf.Variables for state, they are always usable from both contexts
  - tf.Tensors are only for intermediate values
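A minimal sketch of the last point, using a hypothetical stateful layer (the class name and behavior are invented purely for illustration):

import tensorflow as tf

class CallCounter(tf.keras.layers.Layer):
  def build(self, input_shape):
    # State lives in a tf.Variable, usable from eager and tf.function alike.
    self.count = self.add_weight(
        name="count", shape=(), dtype=tf.int32,
        initializer=tf.zeros_initializer(), trainable=False)

  def call(self, inputs):
    self.count.assign_add(1)  # mutate the variable; don't store a tf.Tensor
    return inputs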
A note on Slim and contrib.layers
A large amount of older TensorFlow 1.x code uses the Slim library, which was packaged with TensorFlow 1.x as tf.contrib.layers. As a contrib module, this is no longer available in TensorFlow 2.x, even in tf.compat.v1. Converting code using Slim to TensorFlow 2.x is more involved than converting repositories that use v1.layers. In fact, it may make sense to convert your Slim code to v1.layers first, then convert to Keras.
- Remove arg_scopes; all args need to be explicit.
- If you use them, split normalizer_fn and activation_fn into their own layers.
- Separable conv layers map to one or more different Keras layers (depthwise, pointwise, and separable Keras layers).
- Slim and v1.layers have different argument names and default values.
- Some args have different scales.
- If you use Slim pre-trained models, try out Keras's pre-trained models from tf.keras.applications or TF Hub's TensorFlow 2.x SavedModels exported from the original Slim code.
Some tf.contrib layers might not have been moved to core TensorFlow but have instead been moved to the TensorFlow Addons package.
Training
There are many ways to feed data to a tf.keras model. They will accept Python generators and Numpy arrays as input.

The recommended way to feed data to a model is to use the tf.data package, which contains a collection of high-performance classes for manipulating data.

If you are still using tf.queue, these are now only supported as data structures, not as input pipelines.
Using TensorFlow Datasets
The TensorFlow Datasets package (tfds) contains utilities for loading predefined datasets as tf.data.Dataset objects.

For this example, you can load the MNIST dataset using tfds:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/3.0.1... WARNING:absl:Dataset mnist is hosted on GCS. It will automatically be downloaded to your local data directory. If you'd instead prefer to read directly from our public GCS bucket (recommended if you're running on GCP), you can instead pass `try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`. Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
Then prepare the data for training:
- Re-scale each image.
- Shuffle the order of the examples.
- Collect batches of images and labels.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5
def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  return image, label
To keep the example short, trim the dataset to only return 5 batches:
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
Use Keras training loops
If you don't need low-level control of your training process, using Keras's built-in fit, evaluate, and predict methods is recommended. These methods provide a uniform interface to train the model regardless of the implementation (sequential, functional, or subclassed).
The advantages of these methods include:
- They accept Numpy arrays, Python generators and tf.data.Datasets.
- They apply regularization and activation losses automatically.
- They support tf.distribute for multi-device training.
- They support arbitrary callables as losses and metrics.
- They support callbacks like tf.keras.callbacks.TensorBoard, and custom callbacks.
- They are performant, automatically using TensorFlow graphs.
Here is an example of training a model using a Dataset. (For details on how this works, check out the tutorials section.)
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 1s 9ms/step - loss: 2.0191 - accuracy: 0.3608
Epoch 2/5
5/5 [==============================] - 0s 9ms/step - loss: 0.4736 - accuracy: 0.9059
Epoch 3/5
5/5 [==============================] - 0s 8ms/step - loss: 0.2973 - accuracy: 0.9626
Epoch 4/5
5/5 [==============================] - 0s 9ms/step - loss: 0.2108 - accuracy: 0.9911
Epoch 5/5
5/5 [==============================] - 0s 8ms/step - loss: 0.1791 - accuracy: 0.9874
5/5 [==============================] - 0s 6ms/step - loss: 1.5504 - accuracy: 0.7500
Loss 1.5504140853881836, Accuracy 0.75
Write your own loop
If the Keras model's training step works for you, but you need more control outside that step, consider using the tf.keras.Model.train_on_batch method in your own data-iteration loop.

Remember: Many things can be implemented as a tf.keras.callbacks.Callback.

This method has many of the advantages of the methods mentioned in the previous section, but gives the user control of the outer loop.
You can also use tf.keras.Model.test_on_batch or tf.keras.Model.evaluate to check performance during training.
To continue training the above model:
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
for epoch in range(NUM_EPOCHS):
  # Reset the metric accumulators
  model.reset_metrics()

  for image_batch, label_batch in train_data:
    result = model.train_on_batch(image_batch, label_batch)
    metrics_names = model.metrics_names
    print("train: ",
          "{}: {:.3f}".format(metrics_names[0], result[0]),
          "{}: {:.3f}".format(metrics_names[1], result[1]))
  for image_batch, label_batch in test_data:
    result = model.test_on_batch(image_batch, label_batch,
                                 # Return accumulated metrics
                                 reset_metrics=False)
  metrics_names = model.metrics_names
  print("\neval: ",
        "{}: {:.3f}".format(metrics_names[0], result[0]),
        "{}: {:.3f}".format(metrics_names[1], result[1]))
train:  loss: 0.138 accuracy: 1.000
train:  loss: 0.161 accuracy: 1.000
train:  loss: 0.159 accuracy: 0.969
train:  loss: 0.241 accuracy: 0.953
train:  loss: 0.172 accuracy: 0.969

eval:  loss: 1.550 accuracy: 0.800

train:  loss: 0.086 accuracy: 1.000
train:  loss: 0.094 accuracy: 1.000
train:  loss: 0.090 accuracy: 1.000
train:  loss: 0.119 accuracy: 0.984
train:  loss: 0.099 accuracy: 1.000

eval:  loss: 1.558 accuracy: 0.841

train:  loss: 0.076 accuracy: 1.000
train:  loss: 0.068 accuracy: 1.000
train:  loss: 0.061 accuracy: 1.000
train:  loss: 0.076 accuracy: 1.000
train:  loss: 0.076 accuracy: 1.000

eval:  loss: 1.536 accuracy: 0.841

train:  loss: 0.059 accuracy: 1.000
train:  loss: 0.056 accuracy: 1.000
train:  loss: 0.058 accuracy: 1.000
train:  loss: 0.054 accuracy: 1.000
train:  loss: 0.055 accuracy: 1.000

eval:  loss: 1.497 accuracy: 0.863

train:  loss: 0.053 accuracy: 1.000
train:  loss: 0.049 accuracy: 1.000
train:  loss: 0.044 accuracy: 1.000
train:  loss: 0.049 accuracy: 1.000
train:  loss: 0.045 accuracy: 1.000

eval:  loss: 1.463 accuracy: 0.878
Customize the training step
If you need more flexibility and control, you can have it by implementing your own training loop. There are three steps:

- Iterate over a Python generator or tf.data.Dataset to get batches of examples.
- Use tf.GradientTape to collect gradients.
- Use one of the tf.keras.optimizers to apply weight updates to the model's variables.
Remember:
- Always include a training argument on the call method of subclassed layers and models.
- Make sure to call the model with the training argument set correctly.
- Depending on usage, model variables may not exist until the model is run on a batch of data.
- You need to manually handle things like regularization losses for the model.
Note the simplifications relative to v1:
- There is no need to run variable initializers. Variables are initialized on creation.
- There is no need to add manual control dependencies. Even in tf.function, operations act as in eager mode.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
Finished epoch 0 Finished epoch 1 Finished epoch 2 Finished epoch 3 Finished epoch 4
New-style metrics and losses
In TensorFlow 2.x, metrics and losses are objects. These work both eagerly and in tf.functions.

A loss object is callable, and expects (y_true, y_pred) as arguments:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815
A metric object has the following methods:

- Metric.update_state(): add new observations.
- Metric.result(): get the current result of the metric, given the observed values.
- Metric.reset_states(): clear all observations.

The object itself is callable. Calling updates the state with new observations, as with update_state, and returns the new result of the metric.
You don't have to manually initialize a metric's variables, and because TensorFlow 2.x has automatic control dependencies, you don't need to worry about those either.
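Before the full loop, a minimal sketch of the object API (the observed values are arbitrary):

import tensorflow as tf

m = tf.keras.metrics.Mean()
m.update_state([1, 2, 3])
print(m.result().numpy())  # 2.0

m([4])                     # calling updates state and returns the new result
print(m.result().numpy())  # 2.5

m.reset_states()           # back to an empty accumulator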
The code below uses a metric to keep track of the mean loss observed within a custom training loop.
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)

for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss = loss_metric.result()
  mean_accuracy = accuracy_metric.result()
  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
Epoch:  0
  loss:     0.139
  accuracy: 0.997
Epoch:  1
  loss:     0.116
  accuracy: 1.000
Epoch:  2
  loss:     0.105
  accuracy: 0.997
Epoch:  3
  loss:     0.089
  accuracy: 1.000
Epoch:  4
  loss:     0.078
  accuracy: 1.000
Keras metric names
In TensorFlow 2.x, Keras models are more consistent about handling metric names.
Now when you pass a string in the list of metrics, that exact string is used as the metric's name. These names are visible in the history object returned by model.fit, and in the logs passed to keras.callbacks.
model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 8ms/step - loss: 0.0901 - acc: 0.9923 - accuracy: 0.9923 - my_accuracy: 0.9923
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])
This differs from previous versions, where passing metrics=["accuracy"] would result in dict_keys(['loss', 'acc']).
Keras optimizers
The optimizers in v1.train, such as v1.train.AdamOptimizer and v1.train.GradientDescentOptimizer, have equivalents in tf.keras.optimizers.
Convert v1.train to keras.optimizers
Here are things to keep in mind when converting your optimizers:
- Upgrading your optimizers may make old checkpoints incompatible.
- All epsilons now default to 1e-7 instead of 1e-8 (which is negligible in most use cases).
- v1.train.GradientDescentOptimizer can be directly replaced by tf.keras.optimizers.SGD.
- v1.train.MomentumOptimizer can be directly replaced by the SGD optimizer using the momentum argument: tf.keras.optimizers.SGD(..., momentum=...).
- v1.train.AdamOptimizer can be converted to use tf.keras.optimizers.Adam. The beta1 and beta2 arguments have been renamed to beta_1 and beta_2 (see the sketch after this list).
- v1.train.RMSPropOptimizer can be converted to tf.keras.optimizers.RMSprop. The decay argument has been renamed to rho.
- v1.train.AdadeltaOptimizer can be converted directly to tf.keras.optimizers.Adadelta.
- tf.train.AdagradOptimizer can be converted directly to tf.keras.optimizers.Adagrad.
- tf.train.FtrlOptimizer can be converted directly to tf.keras.optimizers.Ftrl. The accum_name and linear_name arguments have been removed.
- tf.contrib.AdamaxOptimizer and tf.contrib.NadamOptimizer can be converted directly to tf.keras.optimizers.Adamax and tf.keras.optimizers.Nadam, respectively. The beta1 and beta2 arguments have been renamed to beta_1 and beta_2.
New defaults for some tf.keras.optimizers

There are no changes for optimizers.SGD, optimizers.Adam, or optimizers.RMSprop.
The following default learning rates have changed:
- optimizers.Adagrad from 0.01 to 0.001
- optimizers.Adadelta from 1.0 to 0.001
- optimizers.Adamax from 0.002 to 0.001
- optimizers.Nadam from 0.002 to 0.001
TensorBoard
TensorFlow 2.x includes significant changes to the tf.summary API used to write summary data for visualization in TensorBoard. For a general introduction to the new tf.summary, there are several tutorials available that use the TensorFlow 2.x API. This includes a TensorBoard TensorFlow 2.x migration guide.
Saving and loading
Checkpoint compatibility
TensorFlow 2.x uses object-based checkpoints.
Old-style name-based checkpoints can still be loaded, if you're careful. The code conversion process may result in variable name changes, but there are workarounds.
The simplest approach is to line up the names of the new model with the names in the checkpoint:
- Variables still all have a name argument you can set.
- Keras models also take a name argument, which they set as the prefix for their variables.
- The v1.name_scope function can be used to set variable name prefixes. This is very different from tf.variable_scope. It only affects names, and doesn't track variables and reuse.
If that does not work for your use case, try the v1.train.init_from_checkpoint function. It takes an assignment_map argument, which specifies the mapping from old names to new names.
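A minimal sketch of the assignment_map usage; the checkpoint path and scope names below are hypothetical:

import tensorflow as tf

tf.compat.v1.train.init_from_checkpoint(
    '/path/to/old_checkpoint',                    # hypothetical checkpoint
    assignment_map={'old_scope/': 'new_scope/'})  # old-name prefix -> new-name prefix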
The TensorFlow Estimator repository includes a conversion tool to upgrade the checkpoints for premade estimators from TensorFlow 1.x to 2.0. It may serve as an example of how to build a tool for a similar use case.
Saved models compatibility
There are no significant compatibility concerns for saved models.
- TensorFlow 1.x saved_models work in TensorFlow 2.x.
- TensorFlow 2.x saved_models work in TensorFlow 1.x if all the ops are supported.
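For reference, a minimal save/load round trip with the TensorFlow 2.x API (the model and path are illustrative only):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
tf.saved_model.save(model, '/tmp/example_saved_model')   # hypothetical path
loaded = tf.saved_model.load('/tmp/example_saved_model')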
A Graph.pb or Graph.pbtxt
There is no straightforward way to upgrade a raw Graph.pb file to TensorFlow 2.x. Your best bet is to upgrade the code that generated the file.

But, if you have a "frozen graph" (a tf.Graph where the variables have been turned into constants), then it is possible to convert this to a concrete_function using v1.wrap_function:
def wrap_frozen_graph(graph_def, inputs, outputs):
  def _imports_graph_def():
    tf.compat.v1.import_graph_def(graph_def, name="")
  wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      tf.nest.map_structure(import_graph.as_graph_element, inputs),
      tf.nest.map_structure(import_graph.as_graph_element, outputs))
For example, here is a frozen graph for Inception v1, from 2016:
path = tf.keras.utils.get_file(
    'inception_v1_2016_08_28_frozen.pb',
    'http://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz',
    untar=True)
Downloading data from http://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz 24698880/24695710 [==============================] - 1s 0us/step
Load the tf.GraphDef:
graph_def = tf.compat.v1.GraphDef()
loaded = graph_def.ParseFromString(open(path,'rb').read())
Wrap it into a concrete_function:
inception_func = wrap_frozen_graph(
    graph_def, inputs='input:0',
    outputs='InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu:0')
Pass it a tensor as input:
input_img = tf.ones([1,224,224,3], dtype=tf.float32)
inception_func(input_img).shape
TensorShape([1, 28, 28, 96])
Estimators
Training with Estimators
Estimators are supported in TensorFlow 2.x.
When you use estimators, you can use input_fn, tf.estimator.TrainSpec, and tf.estimator.EvalSpec from TensorFlow 1.x.

Here is an example using input_fn with train and evaluate specs.
Creating the input_fn and train/eval specs
# Define the estimator's input_fn
def input_fn():
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000
  BATCH_SIZE = 64

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label[..., tf.newaxis]

  train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  return train_data.repeat()

# Define train and eval specs
train_spec = tf.estimator.TrainSpec(input_fn=input_fn,
                                    max_steps=STEPS_PER_EPOCH * NUM_EPOCHS)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn,
                                  steps=STEPS_PER_EPOCH)
Using a Keras model definition
There are some differences in how to construct your estimators in TensorFlow 2.x.
It's recommended that you define your model using Keras, then use the tf.keras.estimator.model_to_estimator utility to turn your model into an estimator. The code below shows how to use this utility when creating and training an estimator.
def make_model():
  return tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu',
                             kernel_regularizer=tf.keras.regularizers.l2(0.02),
                             input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dropout(0.1),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.Dense(10)
  ])
model = make_model()
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model
)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp0erq3im2
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/backend.py:434: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp0erq3im2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp0erq3im2/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmp0erq3im2/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 8 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp0erq3im2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.4717796, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmp0erq3im2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py:2325: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:17Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp0erq3im2/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.86556s
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:18
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.6, global_step = 25, loss = 1.6160676
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmp0erq3im2/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.37597787.
({'accuracy': 0.6, 'loss': 1.6160676, 'global_step': 25}, [])
Using a custom model_fn
If you have an existing custom estimator model_fn that you need to maintain, you can convert your model_fn to use a Keras model.

However, for compatibility reasons, a custom model_fn will still run in 1.x-style graph mode. This means there is no eager execution and no automatic control dependencies.
Custom model_fn with minimal changes
To make your custom model_fn work in TensorFlow 2.x, if you prefer minimal changes to the existing code, tf.compat.v1 symbols such as optimizers and metrics can be used.

Using a Keras model in a custom model_fn is similar to using it in a custom training loop:
- Set the training phase appropriately, based on the mode argument.
- Explicitly pass the model's trainable_variables to the optimizer.
But there are important differences, relative to a custom loop:
- Instead of using Model.losses, extract the losses using Model.get_losses_for.
- Extract the model's updates using Model.get_updates_for.
The following code creates an estimator from a custom model_fn, illustrating all of these concerns.
def my_model_fn(features, labels, mode):
  model = make_model()

  optimizer = tf.compat.v1.train.AdamOptimizer()
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  training = (mode == tf.estimator.ModeKeys.TRAIN)
  predictions = model(features, training=training)

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  total_loss = loss_fn(labels, predictions) + tf.math.add_n(reg_losses)

  accuracy = tf.compat.v1.metrics.accuracy(labels=labels,
                                           predictions=tf.math.argmax(predictions, axis=1),
                                           name='acc_op')

  update_ops = model.get_updates_for(None) + model.get_updates_for(features)
  minimize_op = optimizer.minimize(
      total_loss,
      var_list=model.trainable_variables,
      global_step=tf.compat.v1.train.get_or_create_global_step())
  train_op = tf.group(minimize_op, update_ops)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=total_loss,
      train_op=train_op,
      eval_metric_ops={'accuracy': accuracy})
# Create the Estimator & Train
estimator = tf.estimator.Estimator(model_fn=my_model_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpifj8mysl
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpifj8mysl', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpifj8mysl/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 3.0136237, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmpifj8mysl/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:20Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpifj8mysl/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.97406s
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:21
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.59375, global_step = 25, loss = 1.6248872
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmpifj8mysl/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.35726172.
({'accuracy': 0.59375, 'loss': 1.6248872, 'global_step': 25}, [])
Custom model_fn with TensorFlow 2.x symbols

If you want to get rid of all TensorFlow 1.x symbols and upgrade your custom model_fn to TensorFlow 2.x, you need to update the optimizer and metrics to tf.keras.optimizers and tf.keras.metrics.
In the custom model_fn, besides the above changes, more upgrades need to be made:

- Use tf.keras.optimizers instead of v1.train.Optimizer.
- Explicitly pass the model's trainable_variables to the tf.keras.optimizers.
- To compute the train_op/minimize_op:
  - Use Optimizer.get_updates if the loss is a scalar loss Tensor (not a callable). The first element in the returned list is the desired train_op/minimize_op.
  - If the loss is a callable (such as a function), use Optimizer.minimize to get the train_op/minimize_op.
- Use tf.keras.metrics instead of tf.compat.v1.metrics for evaluation.
For the above example of my_model_fn, the migrated code with TensorFlow 2.x symbols is shown as:
def my_model_fn(features, labels, mode):
  model = make_model()

  training = (mode == tf.estimator.ModeKeys.TRAIN)
  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  predictions = model(features, training=training)

  # Get both the unconditional losses (the None part)
  # and the input-conditional losses (the features part).
  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  total_loss = loss_obj(labels, predictions) + tf.math.add_n(reg_losses)

  # Upgrade to tf.keras.metrics.
  accuracy_obj = tf.keras.metrics.Accuracy(name='acc_obj')
  accuracy = accuracy_obj.update_state(
      y_true=labels, y_pred=tf.math.argmax(predictions, axis=1))

  train_op = None
  if training:
    # Upgrade to tf.keras.optimizers.
    optimizer = tf.keras.optimizers.Adam()

    # Manually assign tf.compat.v1.global_step variable to optimizer.iterations
    # to make tf.compat.v1.train.global_step increase correctly.
    # This assignment is a must for any `tf.train.SessionRunHook` specified in
    # the estimator, as SessionRunHooks rely on the global step.
    optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()

    # Get both the unconditional updates (the None part)
    # and the input-conditional updates (the features part).
    update_ops = model.get_updates_for(None) + model.get_updates_for(features)
    # Compute the minimize_op.
    minimize_op = optimizer.get_updates(
        total_loss,
        model.trainable_variables)[0]
    train_op = tf.group(minimize_op, *update_ops)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions=predictions,
      loss=total_loss,
      train_op=train_op,
      eval_metric_ops={'Accuracy': accuracy_obj})
# Create the Estimator and train.
estimator = tf.estimator.Estimator(model_fn=my_model_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpc93qfnv6
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpc93qfnv6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpc93qfnv6/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.5293791, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmp/tmpc93qfnv6/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-01-06T02:31:24Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpc93qfnv6/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.86534s
INFO:tensorflow:Finished evaluation at 2021-01-06-02:31:25
INFO:tensorflow:Saving dict for global step 25: Accuracy = 0.59375, global_step = 25, loss = 1.7570661
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tmpc93qfnv6/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.47094986.
({'Accuracy': 0.59375, 'loss': 1.7570661, 'global_step': 25}, [])
Premade Estimators
Premade Estimators in the family of `tf.estimator.DNN*`, `tf.estimator.Linear*` and `tf.estimator.DNNLinearCombined*` are still supported in the TensorFlow 2.x API. However, some arguments have changed:

- `input_layer_partitioner`: Removed in v2.
- `loss_reduction`: Updated to `tf.keras.losses.Reduction` instead of `tf.compat.v1.losses.Reduction`. Its default value has also changed to `tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE` from `tf.compat.v1.losses.Reduction.SUM`.
- `optimizer`, `dnn_optimizer` and `linear_optimizer`: These arguments have been updated to accept `tf.keras.optimizers` instances instead of `tf.compat.v1.train.Optimizer` subclasses.
To migrate the above changes:

- No migration is needed for `input_layer_partitioner` since the Distribution Strategy will handle it automatically in TensorFlow 2.x.
- For `loss_reduction`, check `tf.keras.losses.Reduction` for the supported options.
- For `optimizer` arguments:
    - If you do not 1) pass in the `optimizer`, `dnn_optimizer` or `linear_optimizer` argument, or 2) specify the `optimizer` argument as a `string` in your code, then you don't need to change anything because `tf.keras.optimizers` is used by default.
    - Otherwise, you need to update it from `tf.compat.v1.train.Optimizer` to its corresponding `tf.keras.optimizers` class, as in the sketch below.
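For instance, if your code constructs a canned Estimator with a `v1.train` optimizer, the change is a one-line swap. Here is a minimal sketch; the feature column `x`, `hidden_units` and `n_classes` are hypothetical placeholders for your own model definition:

```python
import tensorflow as tf

# Placeholder model definition; substitute your own feature columns.
feature_columns = [tf.feature_column.numeric_column('x', shape=[4])]

# TensorFlow 1.x: the optimizer was a tf.compat.v1.train.Optimizer.
# estimator = tf.estimator.DNNClassifier(
#     feature_columns=feature_columns,
#     hidden_units=[32, 16],
#     n_classes=3,
#     optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=0.01))

# TensorFlow 2.x: pass the corresponding tf.keras.optimizers instance.
estimator = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[32, 16],
    n_classes=3,
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))
```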
Checkpoint converter
The migration to `tf.keras.optimizers` will break checkpoints saved using TensorFlow 1.x, as `tf.keras.optimizers` generates a different set of variables to be saved in checkpoints. To make old checkpoints reusable after your migration to TensorFlow 2.x, try the checkpoint converter tool.
curl -O https://raw.githubusercontent.com/tensorflow/estimator/master/tensorflow_estimator/python/estimator/tools/checkpoint_converter.py
The tool has built-in help:
python checkpoint_converter.py -h
usage: checkpoint_converter.py [-h]
                               {dnn,linear,combined} source_checkpoint
                               source_graph target_checkpoint

positional arguments:
  {dnn,linear,combined}  The type of estimator to be converted. So far, the
                         checkpoint converter only supports Canned Estimator.
                         So the allowed types include linear, dnn and
                         combined.
  source_checkpoint      Path to source checkpoint file to be read in.
  source_graph           Path to source graph file to be read in.
  target_checkpoint      Path to checkpoint file to be written out.

optional arguments:
  -h, --help             show this help message and exit
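For example, converting a DNN Estimator checkpoint might look like the following (the paths here are hypothetical placeholders):

python checkpoint_converter.py dnn /tmp/source.ckpt /tmp/source.graph /tmp/target.ckpt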
TensorShape
This class was simplified to hold `int`s instead of `tf.compat.v1.Dimension` objects, so there is no need to call `.value` to get an `int`.

Individual `tf.compat.v1.Dimension` objects are still accessible from `tf.TensorShape.dims`.
The following demonstrates the differences between TensorFlow 1.x and TensorFlow 2.x.
# Create a shape and choose an index
i = 0
shape = tf.TensorShape([16, None, 256])
shape
TensorShape([16, None, 256])
If you had this in TensorFlow 1.x:
value = shape[i].value
Then do this in TensorFlow 2.x:
value = shape[i]
value
16
If you had this in TensorFlow 1.x:
for dim in shape:
value = dim.value
print(value)
Then do this in TensorFlow 2.x:
for value in shape:
print(value)
16
None
256
If you had this in TensorFlow 1.x (or used any other dimension method):
dim = shape[i]
dim.assert_is_compatible_with(other_dim)
Then do this in TensorFlow 2.x:
other_dim = 16
Dimension = tf.compat.v1.Dimension
if shape.rank is None:
dim = Dimension(None)
else:
dim = shape.dims[i]
dim.is_compatible_with(other_dim) # or any other dimension method
True
shape = tf.TensorShape(None)
if shape:
dim = shape.dims[i]
dim.is_compatible_with(other_dim) # or any other dimension method
The boolean value of a `tf.TensorShape` is `True` if the rank is known, `False` otherwise.
print(bool(tf.TensorShape([]))) # Scalar
print(bool(tf.TensorShape([0]))) # 0-length vector
print(bool(tf.TensorShape([1]))) # 1-length vector
print(bool(tf.TensorShape([None]))) # Unknown-length vector
print(bool(tf.TensorShape([1, 10, 100]))) # 3D tensor
print(bool(tf.TensorShape([None, None, None]))) # 3D tensor with no known dimensions
print()
print(bool(tf.TensorShape(None))) # A tensor with unknown rank.
True
True
True
True
True
True

False
Other changes
- Remove `tf.colocate_with`: TensorFlow's device placement algorithms have improved significantly. This should no longer be necessary. If removing it causes a performance degradation, please file a bug.
- Replace `v1.ConfigProto` usage with the equivalent functions from `tf.config`, as in the sketch below.
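As an example, a common `ConfigProto` pattern for GPU memory growth maps onto `tf.config` roughly as follows. This is a minimal sketch; other `ConfigProto` fields in your code may map to different `tf.config` functions:

```python
import tensorflow as tf

# TensorFlow 1.x: session-level configuration via ConfigProto.
# config = tf.compat.v1.ConfigProto()
# config.gpu_options.allow_growth = True
# sess = tf.compat.v1.Session(config=config)

# TensorFlow 2.x: configure devices directly through tf.config.
for gpu in tf.config.list_physical_devices('GPU'):
  tf.config.experimental.set_memory_growth(gpu, True)

# Thread pool sizes, previously ConfigProto fields, also move to tf.config.
tf.config.threading.set_inter_op_parallelism_threads(2)
```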
Conclusions
The overall process is:

- Run the upgrade script.
- Remove contrib symbols.
- Switch your models to an object-oriented style (Keras).
- Use `tf.keras` or `tf.estimator` training and evaluation loops where you can.
- Otherwise, use custom loops, but be sure to avoid sessions and collections.
It takes a little work to convert code to idiomatic TensorFlow 2.x, but every change results in:
- Fewer lines of code.
- Increased clarity and simplicity.
- Easier debugging.