
Save and load a model using a distribution strategy



Overview

This tutorial demonstrates how you can save and load models in a SavedModel format with tf.distribute.Strategy during or after training. There are two kinds of APIs for saving and loading a Keras model: high-level (tf.keras.Model.save and tf.keras.models.load_model) and low-level (tf.saved_model.save and tf.saved_model.load).

To learn about SavedModel and serialization in general, read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example.

Import dependencies:

import tensorflow_datasets as tfds

import tensorflow as tf
2022-10-20 04:28:35.267627: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-10-20 04:28:35.267741: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory
2022-10-20 04:28:35.267753: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.

Load and prepare the data with TensorFlow Datasets and tf.data, and create the model using tf.distribute.MirroredStrategy:

mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets = tfds.load(name='mnist', as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
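In get_data, the dataset is batched with a global batch size equal to the per-replica batch size multiplied by num_replicas_in_sync, and the strategy then splits each global batch across the replicas. The sketch below checks that arithmetic on whatever devices are available. It assumes a small synthetic dataset of random images and labels as a hypothetical stand-in for MNIST, so it runs without downloading anything. On a machine without GPUs, MirroredStrategy falls back to a single CPU replica, so the global batch size is simply 64.

```python
import numpy as np
import tensorflow as tf

# MirroredStrategy uses every visible GPU and falls back to one CPU replica if none is found.
strategy = tf.distribute.MirroredStrategy()

BATCH_SIZE_PER_REPLICA = 64
# The dataset is batched with the global batch size; the strategy then
# splits each global batch across the replicas.
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Hypothetical stand-in for MNIST: random 28x28x1 images with labels in [0, 10).
images = np.random.rand(1024, 28, 28, 1).astype("float32")
labels = np.random.randint(0, 10, size=(1024,)).astype("int64")
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(GLOBAL_BATCH_SIZE)

# One global batch, as produced by the undistributed dataset.
first_images, first_labels = next(iter(dataset))

# The same kind of batch after distribution: one slice per replica.
dist_dataset = strategy.experimental_distribute_dataset(dataset)
per_replica_images, _ = next(iter(dist_dataset))
local_batches = strategy.experimental_local_results(per_replica_images)

print("replicas:", strategy.num_replicas_in_sync)
print("global batch size:", GLOBAL_BATCH_SIZE)
print("global batch shape:", first_images.shape)
print("per-replica batch sizes:", [int(t.shape[0]) for t in local_batches])
```

On the four-GPU machine used for this tutorial, the same arithmetic gives a global batch of 64 × 4 = 256 split into four per-replica batches of 64, which matches the (256, 10) and (64, 10) output shapes printed later.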

Train the model with tf.keras.Model.fit:

model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
2022-10-20 04:28:42.672422: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:549] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
235/235 [==============================] - 13s 12ms/step - loss: 0.3582 - sparse_categorical_accuracy: 0.8982
Epoch 2/2
235/235 [==============================] - 2s 10ms/step - loss: 0.1315 - sparse_categorical_accuracy: 0.9621
<keras.callbacks.History at 0x7fa41584ffd0>

Save and load the model

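Now that you have a simple model to work with, let's explore the saving and loading APIs. There are two kinds of APIs available: the high-level Keras API (tf.keras.Model.save and tf.keras.models.load_model) and the low-level tf.saved_model API (tf.saved_model.save and tf.saved_model.load).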

The Keras API

Here is an example of saving and loading a model with the Keras API:

keras_model_path = '/tmp/keras_save'
model.save(keras_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets

Restore the model without tf.distribute.Strategy:

restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
235/235 [==============================] - 2s 5ms/step - loss: 0.0864 - sparse_categorical_accuracy: 0.9751
Epoch 2/2
235/235 [==============================] - 1s 5ms/step - loss: 0.0616 - sparse_categorical_accuracy: 0.9829
<keras.callbacks.History at 0x7fa3ec314e50>

After restoring the model, you can continue training it without calling Model.compile again, because it was already compiled before saving. The model is saved in TensorFlow's standard SavedModel proto format. For more information, refer to the guide to the SavedModel format.
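A SavedModel is a directory rather than a single file: saved_model.pb holds the serialized TensorFlow program and its signatures, and the variables/ subdirectory holds the trained weights (the assets/ subdirectory mentioned in the log above holds any auxiliary files). The sketch below writes a SavedModel and lists what ends up on disk. To stay independent of Keras version differences, it uses the lower-level tf.saved_model.save on a tiny hypothetical tf.Module called Scaler, which is not part of the tutorial. Newer TensorFlow releases may write additional files, so only the core entries are checked.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical minimal module: multiplies its input by a trainable factor."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.factor


export_dir = os.path.join(tempfile.mkdtemp(), "scaler")
tf.saved_model.save(Scaler(), export_dir)

# Top-level layout of the SavedModel directory.
contents = sorted(os.listdir(export_dir))
# The checkpoint files that store the variable values.
variable_files = sorted(os.listdir(os.path.join(export_dir, "variables")))
print(contents)
print(variable_files)
```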

Now, restore the model and train it using a tf.distribute.Strategy:

another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
2022-10-20 04:29:03.247887: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:549] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
2022-10-20 04:29:03.305683: W tensorflow/core/framework/dataset.cc:769] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
235/235 [==============================] - 3s 12ms/step - loss: 0.0883 - sparse_categorical_accuracy: 0.9747
Epoch 2/2
235/235 [==============================] - 3s 12ms/step - loss: 0.0638 - sparse_categorical_accuracy: 0.9817

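As the Model.fit output shows, the restored model trains as expected under a tf.distribute.Strategy. The strategy used for loading does not have to be the same one used before saving: here the model was trained and saved under MirroredStrategy and restored under OneDeviceStrategy.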
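The same independence holds for the lower-level tf.saved_model API introduced in the next section. As a minimal sketch, assuming a hypothetical Scaler module that is not part of the tutorial, the variable is created under MirroredStrategy, saved, and then restored under OneDeviceStrategy on the CPU; the restored function still computes the same result.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical module used only for illustration."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.factor


# Create the variable under MirroredStrategy, so it becomes a mirrored variable.
save_strategy = tf.distribute.MirroredStrategy()
with save_strategy.scope():
    module = Scaler()

export_dir = os.path.join(tempfile.mkdtemp(), "scaler")
tf.saved_model.save(module, export_dir)

# Restore under a different strategy: a single CPU device.
load_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with load_strategy.scope():
    restored = tf.saved_model.load(export_dir)

result = restored(tf.constant([1.0, 2.0, 3.0]))
print(result.numpy())  # [3. 6. 9.]
```

Loading with no strategy at all should give the same values, since the saved variable values do not depend on the strategy that created them.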

The tf.saved_model API

Saving the model with the lower-level API is similar to saving it with the Keras API:

model = get_model()  # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _update_step_xla while saving (showing 2 of 2). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets

Loading can be done with tf.saved_model.load. However, because it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used for inference. For example:

DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

The loaded object may contain multiple functions, each associated with a key. The "serving_default" key is the default key for the inference function of a saved Keras model. To run inference with this function:

predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
{'dense_3': <tf.Tensor: shape=(256, 10), dtype=float32, numpy=
array([[ 0.05669541, -0.18249513,  0.09468295, ..., -0.15790027,
         0.16903076, -0.29186207],
       [-0.18513727, -0.17730473,  0.00543143, ...,  0.00501085,
         0.09462018, -0.14849868],
       [ 0.00233667, -0.07392745,  0.17142265, ...,  0.02672801,
         0.11125134, -0.1890912 ],
       ...,
       [-0.06212103, -0.04520793,  0.11878827, ..., -0.02461467,
        -0.10978805, -0.18873338],
       [-0.04277811, -0.08950038,  0.15964158, ..., -0.02872062,
        -0.07654422, -0.19493806],
       [-0.10444203,  0.00178434,  0.09272549, ..., -0.02228521,
         0.03169406, -0.27460656]], dtype=float32)>}
2022-10-20 04:29:10.960535: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
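The object returned by tf.saved_model.load exposes its exported functions through the signatures mapping. A signature function is called with keyword arguments and returns a dictionary of named output tensors: for the Keras model above the key is the output layer name (dense_3), while for a plain function that returns a single tensor TensorFlow names the output output_0. The sketch below, which assumes a hypothetical Scaler module that is not part of the tutorial, registers an explicit "serving_default" signature and calls it.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical module used only for illustration."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(0.5)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32, name="x")])
    def __call__(self, x):
        return x * self.factor


module = Scaler()
export_dir = os.path.join(tempfile.mkdtemp(), "scaler")
# Register the function explicitly under the conventional "serving_default" key.
tf.saved_model.save(
    module,
    export_dir,
    signatures={"serving_default": module.__call__.get_concrete_function()},
)

loaded = tf.saved_model.load(export_dir)
print(list(loaded.signatures.keys()))

serving_fn = loaded.signatures["serving_default"]
# Signature functions take keyword arguments named after the inputs.
outputs = serving_fn(x=tf.constant([2.0, 4.0]))
print(outputs)  # a dict with a single "output_0" tensor
```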

You can also load the model and run inference in a distributed manner:

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))
    print(result)
    break
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
2022-10-20 04:29:11.204492: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:549] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
{'dense_3': PerReplica:{
  0: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[ 5.66954128e-02, -1.82495132e-01,  9.46829543e-02,
        -1.99705213e-02,  6.43105432e-03,  1.92189962e-03,
         5.45716546e-02, -1.57900274e-01,  1.69030756e-01,
        -2.91862071e-01],
       [-1.85137272e-01, -1.77304730e-01,  5.43142855e-03,
        -6.61219284e-02,  1.38096586e-02,  1.42965347e-01,
         1.21411182e-01,  5.01085073e-03,  9.46201831e-02,
        -1.48498684e-01],
       [ 2.33667158e-03, -7.39274472e-02,  1.71422645e-01,
         3.61611433e-02, -1.20555192e-01,  7.41227418e-02,
        -2.43873000e-02,  2.67280117e-02,  1.11251339e-01,
        -1.89091206e-01],
       [ 7.09555298e-02, -2.69076303e-02,  1.43073052e-01,
         4.84543145e-02, -5.96515387e-02,  2.92756185e-02,
         6.04203120e-02, -3.50991935e-02,  4.28772122e-02,
        -2.49793217e-01],
       [-4.17066738e-04, -1.53508514e-01,  5.01468517e-02,
         6.86274320e-02, -8.20913985e-02,  2.90331282e-02,
        -6.45995438e-02, -1.36710852e-02, -4.28515375e-02,
        -1.45602733e-01],
       [-2.55423039e-03, -7.22648352e-02,  1.55540943e-01,
        -5.50822914e-02, -1.31751090e-01, -1.63858384e-03,
         1.99921772e-01, -3.41505967e-02,  6.90993071e-02,
        -3.22329402e-01],
       [-6.29692525e-02, -8.97325203e-02, -6.68896735e-03,
        -1.18064523e-01, -6.81343600e-02,  9.87244695e-02,
         1.73352793e-01, -2.41610222e-02,  1.02978230e-01,
        -1.85825273e-01],
       [ 5.07793203e-03,  6.07362017e-03,  1.56539544e-01,
        -5.70530705e-02, -9.69728827e-02,  8.46912861e-02,
         1.58232242e-01,  4.53494489e-04,  4.70230728e-03,
        -3.04220438e-01],
       [-2.83630956e-02,  3.63154262e-02,  1.10633574e-01,
        -1.28635436e-01, -7.21479729e-02,  9.34774727e-02,
         1.74799636e-02,  9.85393301e-04,  1.85287148e-01,
        -3.18643719e-01],
       [ 4.78376746e-02, -1.69033408e-02,  1.90827817e-01,
        -4.44185548e-02, -5.79027236e-02, -4.17400151e-04,
        -1.98327415e-02,  6.94948584e-02, -1.66194886e-02,
        -1.47644699e-01],
       [-7.93924630e-02, -3.12273428e-02, -3.22259963e-03,
        -5.49254753e-03, -6.27758950e-02,  8.76613259e-02,
         1.27670616e-02, -2.20820718e-02, -6.28538951e-02,
        -2.05850273e-01],
       [-2.27267474e-01, -1.35238484e-01,  6.08137548e-02,
         9.79769379e-02, -8.80409926e-02,  1.32549465e-01,
         1.42745063e-01,  5.52276224e-02,  1.12522028e-01,
        -3.31825376e-01],
       [-2.93141864e-02, -1.29181780e-02,  4.71726619e-02,
        -3.17186154e-02, -1.19468346e-01,  1.67779624e-03,
         5.74057996e-02, -6.56485185e-03, -1.03122763e-01,
        -1.71964526e-01],
       [-9.45570506e-03, -1.95114464e-01,  9.45393592e-02,
         5.26209176e-03, -1.00912251e-01,  9.82980244e-03,
         8.42497200e-02, -7.55228847e-02, -1.45815969e-01,
        -1.82608530e-01],
       [ 6.74009994e-02, -1.15426816e-01,  1.04768574e-01,
         2.23918781e-02, -8.85902569e-02, -5.18909469e-03,
         7.35665709e-02, -1.30406946e-01,  6.36447147e-02,
        -2.72003770e-01],
       [-2.85958424e-02, -2.33362317e-01,  1.58907712e-01,
         1.51008368e-02, -1.30212963e-01,  2.78012529e-02,
        -7.83246458e-02,  9.52788889e-02,  3.85322422e-03,
        -2.03781515e-01],
       [-1.53383777e-01, -1.04373753e-01,  1.62467621e-02,
         1.32804960e-01, -4.80064936e-03,  1.30110562e-01,
        -5.72463833e-02, -1.78163107e-02,  2.08516419e-02,
        -1.04129411e-01],
       [-1.25490487e-01, -6.54555112e-02,  9.26738381e-02,
         1.27485879e-02, -1.02886081e-01,  4.75829691e-02,
         7.27233738e-02,  5.86150400e-02,  5.61207533e-04,
        -2.66564548e-01],
       [-5.64045273e-02, -2.21078880e-02,  9.45080072e-02,
        -7.43014924e-03,  4.68149781e-03,  1.49590567e-01,
         1.38935968e-02, -2.10650750e-02, -5.82197495e-02,
        -1.36843339e-01],
       [-5.81762306e-02, -2.32888475e-01,  1.20803401e-01,
         1.58917397e-01, -8.55342597e-02,  6.53628409e-02,
         4.51972634e-02, -2.94546261e-02, -3.35515365e-02,
        -2.02911869e-01],
       [-2.44128630e-02, -6.61968440e-02,  9.65379775e-02,
        -4.31480072e-03, -5.02166487e-02,  5.63264117e-02,
         1.22770034e-02, -5.41117527e-02,  3.23219001e-02,
        -1.75194770e-01],
       [-5.33946753e-02, -5.08727357e-02,  8.76996517e-02,
         2.72703916e-03, -2.84151919e-02,  3.44137624e-02,
         3.10903527e-02, -2.93710288e-02, -2.77243294e-02,
        -1.36631057e-01],
       [-1.21435203e-01,  9.37119052e-02,  3.48305218e-02,
        -2.92496681e-02, -1.14468634e-02,  1.75158754e-01,
        -8.28521326e-03,  4.49379385e-02,  1.32546052e-02,
        -2.53419429e-01],
       [-1.88333899e-01, -1.99804269e-02, -1.21428385e-01,
        -1.10797033e-01,  3.04928795e-02,  1.37095854e-01,
         1.91499770e-01, -1.02875799e-01, -2.36042365e-02,
        -1.19434297e-01],
       [-1.85019344e-01,  1.55878931e-01,  8.07486326e-02,
         7.33169913e-03,  2.91078798e-02,  1.87293380e-01,
         1.18885748e-01,  2.42630318e-02,  1.88828968e-02,
        -1.77141190e-01],
       [-1.52159855e-01, -3.32419351e-02,  1.37156680e-01,
         2.41492987e-02, -4.36935648e-02,  1.04925282e-01,
         5.74045628e-02, -5.16976900e-02,  6.61963001e-02,
        -2.87556529e-01],
       [-6.70766234e-02, -5.09437099e-02,  1.73444226e-01,
        -2.09107846e-02, -2.01174840e-02,  8.07069987e-03,
        -1.79876499e-02,  9.50533524e-03,  6.87135160e-02,
        -1.29187614e-01],
       [-8.46410394e-02, -1.02582596e-01,  9.37556624e-02,
        -5.14442399e-02, -6.98624849e-02,  7.95813352e-02,
         1.04995184e-01, -5.54239899e-02,  4.67140377e-02,
        -1.75024226e-01],
       [ 7.93044567e-02, -9.02793258e-02,  1.23429924e-01,
         8.33548382e-02, -5.20040654e-02,  2.00665668e-02,
        -1.10023186e-01, -4.70002890e-02, -1.00593589e-01,
        -1.79287806e-01],
       [-6.11634329e-02, -4.47400436e-02, -1.54428184e-02,
         4.97855060e-02, -9.75240022e-04,  1.46307319e-01,
         9.99712199e-02, -1.56616896e-01, -1.06521934e-01,
        -2.91525215e-01],
       [-1.43314898e-03, -1.39404997e-01,  1.18955813e-01,
         4.79997024e-02, -8.63237157e-02,  3.22298631e-02,
         1.05759501e-01,  6.32373989e-03, -2.79504433e-02,
        -2.37498477e-01],
       [-1.12238787e-01, -1.80670973e-02,  1.14465959e-01,
        -5.34716919e-02, -3.40092182e-02,  9.78776217e-02,
        -1.58212185e-02,  6.65269047e-02,  1.01974718e-02,
        -1.49566352e-01],
       [-8.87611955e-02, -5.45325130e-02,  9.51355845e-02,
         4.67952341e-02, -6.70888498e-02,  9.37887728e-02,
         1.00791454e-01, -7.00738728e-02,  1.10674225e-01,
        -3.28121126e-01],
       [ 8.33723135e-03,  1.98493712e-02,  9.12100598e-02,
        -6.89653903e-02,  2.48191021e-02,  2.54774578e-02,
         4.05964777e-02,  2.36025676e-02, -1.27023086e-02,
        -1.17614605e-01],
       [-7.08008781e-02, -6.59430102e-02,  7.59359002e-02,
        -1.68799698e-01,  1.92671679e-02,  1.12590134e-01,
         1.25862032e-01, -1.53864287e-02,  5.55155948e-02,
        -1.16730958e-01],
       [ 1.00635722e-01, -1.19294055e-01,  5.16603887e-02,
         2.11694650e-02, -4.89443839e-02,  6.89193234e-03,
        -1.68563873e-01, -1.89470053e-02,  1.29722089e-01,
        -3.26810122e-01],
       [-1.75078169e-01, -9.99596342e-02,  4.43945155e-02,
         1.00126229e-02, -5.74500784e-02,  1.99933611e-02,
        -9.39644128e-03, -1.67780761e-02,  1.87414587e-01,
        -1.79478452e-01],
       [-1.49553806e-01, -1.35560110e-01, -6.84016943e-02,
         1.40467748e-01, -8.39289352e-02,  3.61553729e-02,
        -1.05057806e-02, -1.05920210e-02, -3.50173712e-02,
        -2.35617906e-01],
       [-5.40115871e-02, -4.09744903e-02,  8.92305970e-02,
         4.70547751e-03, -7.94834718e-02,  3.28739956e-02,
         1.57701850e-01, -9.10551324e-02, -1.11330733e-01,
        -1.81056350e-01],
       [ 3.78102027e-02, -1.67116299e-02,  1.79950029e-01,
         9.75334868e-02, -6.49004728e-02,  5.69058917e-02,
        -1.51514728e-02, -2.47532278e-02, -2.35725157e-02,
        -3.28668892e-01],
       [-5.54820262e-02, -2.46991515e-02,  6.07664809e-02,
        -9.02032703e-02,  1.62578449e-02,  7.21036196e-02,
         1.39213428e-01, -4.82221544e-02, -1.82645731e-02,
        -2.19011664e-01],
       [-6.74669221e-02, -4.45094369e-02,  5.06065749e-02,
         6.86510429e-02, -1.46946803e-01,  4.71791700e-02,
         1.32827669e-01, -1.46602884e-01,  8.28201622e-02,
        -3.13783944e-01],
       [-7.81435966e-02,  3.42660248e-02,  9.46574956e-02,
        -2.42023543e-02, -3.82454656e-02,  1.53056592e-01,
         5.33029065e-02, -5.85295521e-02, -3.68238762e-02,
        -2.02501923e-01],
       [-1.03744939e-01, -8.51878375e-02,  6.86752424e-02,
         7.63596594e-02, -1.40484184e-01,  2.11760756e-02,
         5.41022643e-02, -5.48546277e-02, -3.20180133e-03,
        -2.77471662e-01],
       [ 5.94272017e-02, -1.39891207e-01,  7.96461850e-02,
         1.10434555e-02, -1.06250145e-01,  1.57188959e-02,
        -2.42327899e-02, -7.40136355e-02,  2.01686714e-02,
        -2.51626670e-01],
       [-1.42513454e-01,  2.68029142e-02,  1.19365364e-01,
         4.12509404e-02, -5.55558354e-02,  1.56092569e-02,
         9.64010879e-02, -7.82285109e-02,  1.19011290e-01,
        -2.97961354e-01],
       [ 3.45030650e-02, -3.45892087e-02,  6.09124228e-02,
        -1.91001333e-02, -1.05113477e-01, -3.12444307e-02,
         6.26624674e-02, -4.75284569e-02, -5.44013381e-02,
        -2.20760778e-01],
       [-1.79890007e-01, -1.44872025e-01,  4.31203991e-02,
         4.92147729e-03, -5.80936521e-02,  1.42654926e-01,
         7.77525678e-02,  5.14203422e-02,  1.26159608e-01,
        -2.70169139e-01],
       [ 5.04556298e-03, -5.98378852e-02,  1.63483873e-01,
         4.61601987e-02, -1.07411996e-01,  5.81188761e-02,
         1.62896648e-01, -2.30355822e-02, -1.16571233e-01,
        -2.80855149e-01],
       [ 5.90390265e-02, -9.50526968e-02,  1.81160837e-01,
         2.80606691e-02, -1.25055060e-01, -3.32352780e-02,
         1.85216293e-02, -1.10664023e-02, -1.38730928e-02,
        -1.71855092e-01],
       [-1.62756518e-01, -3.99812236e-02, -6.31950200e-02,
        -1.31074190e-01, -5.21892570e-02,  9.43032727e-02,
         1.28931955e-01, -2.63837744e-02,  1.19855441e-02,
        -2.66812205e-01],
       [ 1.20649636e-02, -3.76409069e-02,  4.67889421e-02,
        -7.14131594e-02, -7.90149495e-02, -2.63769627e-02,
         5.84802628e-02, -5.48603460e-02, -1.05015859e-01,
        -1.54432252e-01],
       [-7.34755546e-02, -5.33646569e-02,  4.08975035e-02,
         2.27992460e-02,  7.31336698e-03,  8.57038721e-02,
         8.91994983e-02, -7.19701797e-02, -4.47044522e-02,
        -1.51256025e-01],
       [ 8.94926041e-02, -5.79232164e-02,  6.35098666e-02,
        -8.85612369e-02, -1.53377205e-02,  4.78810742e-02,
         8.49968046e-02, -8.07519779e-02, -1.28753074e-02,
        -2.37873405e-01],
       [ 8.25932622e-02, -8.05639699e-02,  1.00028314e-01,
        -1.84262134e-02, -8.89936388e-02,  5.55495061e-02,
         9.58590284e-02, -1.36321038e-02, -2.24590600e-02,
        -3.06535244e-01],
       [-1.04582205e-01, -8.79977196e-02,  6.66528419e-02,
        -5.44024371e-02, -6.58023953e-02,  6.88966513e-02,
         5.55252656e-02, -2.35139504e-02,  1.92917362e-02,
        -1.89358160e-01],
       [-6.86862394e-02, -1.06500089e-03,  6.03844151e-02,
        -5.43629937e-02, -3.32955569e-02,  1.95587963e-01,
         3.79181206e-02,  5.10626063e-02,  3.71694379e-03,
        -1.08448029e-01],
       [-3.07790693e-02, -1.16294108e-01,  1.66772336e-01,
         3.12045217e-04, -7.46974349e-02,  3.31896991e-02,
         9.19365063e-02,  1.91757381e-02,  5.70178032e-03,
        -3.22825909e-01],
       [ 6.84231445e-02, -6.62634298e-02,  5.36312386e-02,
        -6.94844574e-02,  2.14290619e-02, -2.27747187e-02,
         9.34493542e-03,  4.43296544e-02,  1.54955119e-01,
        -2.51641929e-01],
       [-1.05203688e-01, -9.24872383e-02, -4.02869768e-02,
        -6.71340078e-02,  4.58940417e-02,  1.34631842e-01,
        -3.86651754e-02, -9.74073634e-02,  6.26467690e-02,
        -1.60101205e-01],
       [-2.90776975e-02, -3.54301091e-03,  5.10245413e-02,
        -1.14471629e-01, -8.38540345e-02,  1.16656668e-01,
         2.61977017e-02, -7.31423274e-02,  1.44772589e-01,
        -1.88183963e-01],
       [-5.39964586e-02, -8.68393481e-02,  1.88312307e-02,
        -4.57168110e-02, -1.41874388e-01,  7.88870305e-02,
         6.16258234e-02,  2.69923508e-02,  1.50442541e-01,
        -3.74336541e-01],
       [-1.06885269e-01, -1.82386667e-01,  5.96784689e-02,
         1.78860605e-01, -5.11524826e-02,  6.52992427e-02,
         4.16932181e-02, -4.84792665e-02, -9.35789868e-02,
        -1.00470439e-01],
       [-2.52255425e-02, -1.60657242e-02,  1.27352715e-01,
         2.54525617e-03,  1.30776465e-02,  1.02565028e-01,
        -5.63937314e-02,  7.80967027e-02,  8.27538744e-02,
        -1.68258443e-01]], dtype=float32)>,
  1: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-0.15557952,  0.02978121,  0.05786635,  0.05029424, -0.08100337,
         0.08485792,  0.01067205,  0.01930448,  0.17535727, -0.3571884 ],
       [-0.09972841, -0.11503974, -0.01885801, -0.01026693, -0.04118045,
         0.06975561,  0.06720647, -0.10413025, -0.0187361 , -0.15885758],
       [-0.02432652, -0.12444085,  0.09186121,  0.06342544, -0.08932767,
         0.05336151, -0.01659345,  0.06391826, -0.00066344, -0.3844564 ],
       [ 0.00485029,  0.01910353,  0.05061828, -0.07543425, -0.03801598,
         0.03095831,  0.06012974,  0.00190337, -0.07615224, -0.15562594],
       [-0.18798403, -0.14934172, -0.08116613,  0.00228402, -0.11514428,
         0.05111733,  0.09566309, -0.04067147,  0.05310358, -0.1943738 ],
       [-0.10309469, -0.06157941,  0.18939754,  0.00496348, -0.01965233,
        -0.021002  ,  0.02353965,  0.03004068,  0.077181  , -0.16324136],
       [-0.05208187, -0.09821104,  0.08761825,  0.06752673, -0.04673455,
         0.09131843,  0.0447525 , -0.04742849, -0.06361099, -0.09361313],
       [-0.16430864, -0.03272498,  0.04944088,  0.02026294, -0.04837731,
         0.05755142,  0.05726802,  0.00664935, -0.03246859, -0.15461552],
       [ 0.04470079, -0.07252423,  0.1053823 , -0.10679249, -0.08287175,
         0.07955782,  0.16090119, -0.11057827, -0.00176555, -0.26959816],
       [ 0.08050299, -0.16957788,  0.11249913,  0.05064991, -0.13236976,
        -0.03783875,  0.06070975, -0.13841623,  0.03129433, -0.34678075],
       [-0.06692423,  0.03647283,  0.07408667, -0.08492438,  0.00222062,
         0.0713606 ,  0.04264574,  0.04721164, -0.02030249, -0.10677683],
       [-0.10353575, -0.0176539 ,  0.15198028,  0.00679957, -0.04178385,
         0.04443333,  0.06556588,  0.02183904,  0.06406362, -0.23785596],
       [-0.02945594, -0.14488016,  0.17099538,  0.1399909 , -0.17194632,
        -0.01879754, -0.13487527,  0.05174273,  0.12912057, -0.21676761],
       [-0.14953718, -0.06488447,  0.14279692,  0.03902669, -0.14709866,
         0.10345769,  0.20537452, -0.03420053, -0.07028738, -0.2568105 ],
       [ 0.08521292, -0.0621441 ,  0.19209877,  0.05634103, -0.15210289,
        -0.06890653, -0.07356648, -0.05210365,  0.03922088, -0.1925191 ],
       [-0.01990771, -0.07971789,  0.08406444,  0.0081448 ,  0.00640431,
         0.07728601,  0.08533095, -0.06256017, -0.16131699, -0.08592564],
       [-0.1287169 , -0.11683381, -0.07786252, -0.03652582, -0.0815727 ,
         0.04725248,  0.13661352, -0.04126461,  0.0648095 , -0.28296906],
       [-0.08024358, -0.12779033, -0.0186054 ,  0.05329209, -0.06536503,
         0.04326358,  0.02166901, -0.06254478, -0.02099912, -0.24656236],
       [ 0.08458489, -0.12221658,  0.07188946, -0.01981514,  0.01983686,
         0.12816292, -0.00959651, -0.03092006,  0.01995439, -0.1335743 ],
       [ 0.02200603, -0.21340764,  0.10282282,  0.15152301, -0.08658978,
         0.04501382,  0.04275166, -0.14820972, -0.14413334, -0.22990748],
       [-0.02198905, -0.04557067,  0.11271986, -0.11075692, -0.06364421,
         0.05326868, -0.06916784,  0.01903618,  0.08512525, -0.09386544],
       [ 0.00166151,  0.0154434 ,  0.15737936, -0.0189851 , -0.08379708,
         0.09193709,  0.07637928, -0.03276146, -0.07838555, -0.16808079],
       [-0.04008571, -0.06368803,  0.01952498,  0.07385585, -0.0519735 ,
         0.077889  ,  0.07414088, -0.07717514, -0.10453714, -0.23420535],
       [-0.02995382, -0.01897652,  0.18775763, -0.00207931, -0.11569525,
        -0.00693431, -0.03604833, -0.025563  ,  0.08033463, -0.43030012],
       [-0.1470493 ,  0.01611945,  0.03033682, -0.06985208,  0.00528503,
         0.11077203, -0.03070154,  0.09006549,  0.01471576, -0.16579953],
       [-0.08504201, -0.07248062,  0.03537239, -0.0369932 , -0.09552636,
         0.03304154,  0.20565894, -0.07788776,  0.03163466, -0.22605082],
       [ 0.09686472, -0.11042653,  0.15155907,  0.01395462, -0.12393546,
         0.00846791,  0.08766708, -0.05175596, -0.06156413, -0.2353364 ],
       [-0.05071541, -0.04816116,  0.07095262, -0.03019926,  0.00299827,
         0.12138566, -0.04782563, -0.00227739,  0.06629352, -0.16230187],
       [-0.00389401,  0.06519604,  0.06551743, -0.07046416, -0.02309355,
         0.05578852, -0.01843534,  0.05648013, -0.02327764, -0.16490978],
       [-0.05213002,  0.04490723,  0.05639035, -0.08413115, -0.01891122,
         0.07520542,  0.02169598,  0.03918428, -0.05883943, -0.15848005],
       [-0.1175741 , -0.06178616,  0.1186999 ,  0.04556429, -0.08410908,
         0.11694017,  0.02081084, -0.02630647, -0.02287086, -0.13323447],
       [-0.00247236, -0.10758375,  0.04810894,  0.0186372 , -0.04536433,
         0.01386689,  0.05767115, -0.14791854, -0.01040876, -0.17474318],
       [-0.13717747, -0.16606307,  0.03670517,  0.20450178, -0.03616795,
         0.11373236,  0.02372343, -0.05522861, -0.07515197, -0.13478407],
       [ 0.04213857, -0.03004314,  0.14711903, -0.08224294, -0.08938472,
         0.03936704,  0.02627639, -0.01311029,  0.07761003, -0.139985  ],
       [-0.04575709,  0.02908578,  0.00688743, -0.08305559, -0.0269957 ,
         0.07010705, -0.00855571,  0.02523912, -0.04845823, -0.15019837],
       [-0.1127076 ,  0.00092046,  0.05648138,  0.00197556, -0.04494622,
         0.0176085 ,  0.02484498,  0.01815733,  0.05459265, -0.27601105],
       [-0.04552835, -0.03394442,  0.2357163 , -0.01331029, -0.00735985,
         0.07026963,  0.08148949, -0.07085205,  0.04734626, -0.1572459 ],
       [-0.1682607 , -0.13536797, -0.00900383, -0.11871683, -0.01245892,
         0.07945751,  0.17478594, -0.02215804,  0.09954034, -0.22022183],
       [-0.05005278,  0.04173126,  0.0536218 , -0.075551  , -0.05639282,
         0.02294743,  0.04530615,  0.01235813, -0.10502443, -0.20111439],
       [-0.02495864, -0.19908276,  0.08468878,  0.10310897, -0.15697284,
        -0.02932639, -0.01359358,  0.03833605, -0.09492809, -0.24088831],
       [-0.05658239, -0.02568415,  0.0844393 , -0.09471586,  0.03544053,
         0.1013931 ,  0.01293087,  0.02833666, -0.01487281, -0.19686376],
       [-0.08039665, -0.13953386,  0.08977354,  0.17283861, -0.08973014,
         0.05862573, -0.04796278, -0.096262  ,  0.06553136, -0.22311056],
       [ 0.05211451, -0.0499455 ,  0.17860122, -0.02966794, -0.11960588,
        -0.02360127, -0.08064819,  0.0388051 ,  0.03446621, -0.3427848 ],
       [ 0.01560623,  0.01230526,  0.10550052, -0.01491977, -0.05740794,
         0.05162506, -0.00342572,  0.08392474,  0.04462853, -0.25932187],
       [ 0.00664159,  0.03559751,  0.13128954,  0.00820261, -0.05523284,
         0.06153762,  0.0171609 ,  0.01113782, -0.0199614 , -0.08247938],
       [-0.06661854, -0.10679328,  0.16954547,  0.02007622, -0.18793562,
         0.01537754,  0.04251043,  0.03742857,  0.04684321, -0.31569415],
       [-0.05449876, -0.16123614,  0.08592527,  0.10152152, -0.05468833,
         0.05138553,  0.06427715, -0.04681365,  0.03377055, -0.21373469],
       [-0.1349617 ,  0.06846139,  0.13728632, -0.02105965, -0.03193126,
         0.17700602,  0.10054564,  0.00160656, -0.00920562, -0.2622624 ],
       [ 0.00145349, -0.19361266,  0.17052989,  0.06735282, -0.12528062,
        -0.03435481,  0.03527531, -0.06399009, -0.167407  , -0.21836808],
       [ 0.02909416, -0.05092278,  0.21206276,  0.02112886, -0.10762815,
        -0.02272951, -0.02688141,  0.10347199,  0.04323527, -0.12265196],
       [-0.13894093, -0.08943977, -0.07184411, -0.07746585, -0.09789695,
         0.06287996,  0.1378551 , -0.14421883,  0.0875184 , -0.22384655],
       [-0.08915859,  0.10834949,  0.15318877, -0.06765133,  0.08239236,
         0.17874372,  0.00596207,  0.11035417,  0.12501304, -0.10164937],
       [-0.06416424, -0.05682167,  0.14469185, -0.06755608, -0.07797025,
         0.12061951,  0.02685872, -0.06390911,  0.11811011, -0.16710874],
       [ 0.0052513 , -0.07773671, -0.00776505, -0.01327882, -0.0308751 ,
         0.05901426,  0.02541669, -0.06283718, -0.07745493, -0.1425428 ],
       [-0.08401803, -0.03138553,  0.03827956, -0.05427635, -0.07102932,
         0.08863396,  0.13484547, -0.09578835,  0.09038509, -0.20963821],
       [ 0.02169423, -0.0730179 ,  0.09176913, -0.01419144, -0.08356085,
         0.0762079 ,  0.0466857 , -0.07560132,  0.10954135, -0.32165968],
       [ 0.00462115, -0.06759334,  0.14088488, -0.01343207, -0.01847196,
         0.11051806,  0.06660537, -0.12575191, -0.0217132 , -0.2082879 ],
       [-0.04550308, -0.07575703,  0.14320102, -0.14648715, -0.12081976,
         0.00810862,  0.11908692, -0.01190487,  0.10026762, -0.3006044 ],
       [ 0.01165795, -0.08626924,  0.05620405,  0.00065827, -0.03311279,
         0.09165145,  0.10745132, -0.10882889, -0.01949929, -0.16950786],
       [-0.08528612, -0.05171133,  0.18769988, -0.00434185, -0.02478376,
         0.11507122, -0.06345908,  0.0196247 ,  0.10161684, -0.07796416],
       [ 0.0389514 , -0.11548889,  0.19540067,  0.07291568, -0.1618273 ,
         0.04614258,  0.03716869, -0.01372258, -0.00645554, -0.31877935],
       [ 0.0357567 , -0.06215338,  0.19724952,  0.03066552, -0.05183752,
         0.05180987,  0.1381061 , -0.11385053, -0.04381441, -0.27968875],
       [-0.01655988, -0.01554613,  0.12607197, -0.0075617 ,  0.02376728,
         0.05720028,  0.04798015, -0.02381728, -0.01830637, -0.20241779],
       [-0.08650109, -0.09197162,  0.14032865,  0.07020658, -0.09812308,
         0.00492898,  0.05049105,  0.01325693,  0.01997352, -0.22332421]],
      dtype=float32)>,
  2: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-1.31927982e-01,  2.99267918e-02,  1.02791674e-01,
         1.13017969e-02, -3.12998295e-02,  1.11763813e-01,
         1.08509474e-01, -4.64202501e-02, -4.78618592e-03,
        -2.96139300e-01],
       [-9.23422799e-02, -3.91177833e-03,  1.12827897e-01,
        -7.53689259e-02,  1.18766492e-02,  1.27979845e-01,
         1.71304196e-02,  5.68249635e-03,  4.40461226e-02,
        -2.09055200e-01],
       [-1.36937778e-02, -1.60679430e-01,  5.70434369e-02,
         4.82278615e-02, -1.19668320e-01, -2.54901126e-02,
         4.58603352e-02, -3.59684527e-02, -2.71879211e-02,
        -3.36850554e-01],
       [-1.25577986e-01, -3.35626937e-02,  8.76025110e-02,
        -6.07250258e-03, -4.74025905e-02,  8.62924308e-02,
         1.09996140e-01, -3.04777920e-03,  7.94584826e-02,
        -2.57510453e-01],
       [-5.38451821e-02,  3.34944278e-02,  1.15943871e-01,
        -5.92793636e-02, -7.42469914e-03,  1.24184616e-01,
         1.29108161e-01,  7.61713684e-02, -3.83588448e-02,
        -2.23109096e-01],
       [-1.35574296e-01, -6.96650371e-02,  5.87333739e-03,
         1.35618020e-02, -2.34017372e-02,  1.31747812e-01,
         6.37701899e-03, -8.94829407e-02,  2.03509629e-01,
        -3.24395239e-01],
       [-5.20405024e-02, -1.18999295e-02,  7.62922540e-02,
        -4.12799343e-02,  3.92224938e-02,  2.00649649e-01,
         3.36785465e-02, -1.01079457e-02, -4.02318053e-02,
        -1.40806943e-01],
       [-7.16783851e-02, -1.08378306e-01,  5.53451777e-02,
         9.84021202e-02,  1.47578977e-02,  1.55778751e-02,
        -1.73141677e-02, -5.98035157e-02, -7.79365003e-02,
        -2.30279684e-01],
       [ 4.13315855e-02, -2.95323618e-02,  1.37910545e-01,
        -5.74592799e-02, -5.37400246e-02,  7.19866157e-03,
        -3.85879353e-03,  9.30407196e-02,  1.54977422e-02,
        -1.58615112e-01],
       [-5.41910194e-02,  8.29618573e-02,  8.86218995e-02,
        -1.25304490e-01, -4.49769571e-02,  1.10043555e-01,
         3.74448076e-02,  3.05313058e-02, -7.12532848e-02,
        -2.40479320e-01],
       [-9.85056162e-02,  8.99460167e-03,  5.08940183e-02,
        -1.06350884e-01, -6.21864274e-02,  5.33725321e-02,
        -1.17418841e-02,  2.17847023e-02,  5.37730940e-02,
        -2.45024264e-01],
       [ 4.07541543e-02, -1.33765936e-01,  1.09994896e-01,
         2.95611806e-02,  7.95577187e-03,  8.70237947e-02,
        -4.05438319e-02, -1.09169044e-01,  2.96825673e-02,
        -1.04853645e-01],
       [-7.81992078e-02, -3.54774892e-02,  1.48391753e-01,
         4.59513590e-02, -6.12238422e-02,  1.01432897e-01,
         1.99210644e-01, -1.75648555e-02, -1.39748026e-02,
        -3.23453277e-01],
       [-1.14587411e-01,  3.45525444e-02,  1.19312018e-01,
        -6.27360940e-02, -7.58168623e-02,  3.49218175e-02,
         4.01948541e-02,  3.24947760e-04,  1.02774724e-01,
        -2.99923509e-01],
       [-8.70037898e-02, -7.11940527e-02, -3.82863134e-02,
        -7.47348964e-02, -7.51427263e-02,  1.66912861e-02,
        -6.79423138e-02, -1.95411257e-02,  1.20868877e-01,
        -3.44241738e-01],
       [-2.54542589e-01, -3.40743810e-02, -1.43570192e-02,
        -1.23573475e-01,  5.43391481e-02,  1.51960373e-01,
        -8.12942088e-02,  3.19334120e-02,  1.29365653e-01,
        -1.46666661e-01],
       [-1.24278814e-02,  3.00877113e-02,  1.48192912e-01,
        -2.84049176e-02, -9.92633924e-02, -8.79112631e-04,
         2.92625874e-02, -2.25068145e-02,  4.30260226e-02,
        -3.23859096e-01],
       [-1.01887852e-01, -1.22221090e-01,  1.41526029e-01,
         5.18470183e-02, -3.47928926e-02,  5.84272631e-02,
         1.07298344e-01, -8.00721496e-02,  2.11770292e-02,
        -1.58377737e-01],
       [ 8.00552964e-02, -1.18967719e-01,  1.41138405e-01,
         7.58778006e-02, -1.01234481e-01, -4.87289429e-02,
        -5.52615412e-02, -6.77794814e-02,  1.11162737e-02,
        -1.31884813e-01],
       [-2.63284482e-02,  1.23943485e-01,  6.15519285e-02,
        -1.95434690e-02, -7.19256848e-02,  9.81644392e-02,
         1.59366056e-02,  1.70592293e-02,  1.20496526e-02,
        -3.53745639e-01],
       [-1.27118051e-01, -1.16495289e-01,  1.16076544e-02,
        -1.37152344e-01, -6.17223941e-02,  1.00205749e-01,
         1.58786803e-01, -2.97708586e-02,  7.83804432e-03,
        -1.93920076e-01],
       [-1.21620297e-01, -4.39747162e-02,  1.46516353e-01,
         2.30680127e-02,  1.38359517e-02, -2.14859843e-02,
         8.30336437e-02, -2.48198155e-02,  3.63006145e-02,
        -1.58849642e-01],
       [-1.79289103e-01,  5.90423495e-03,  7.34030381e-02,
        -8.64828229e-02, -3.70204672e-02,  8.96793082e-02,
         1.21036842e-02,  8.08985531e-02, -6.46586195e-02,
        -2.09532529e-01],
       [-3.23060155e-03, -4.78113964e-02,  5.37485182e-02,
        -1.58086363e-02, -7.87170976e-02,  1.00667775e-03,
        -2.02964842e-02, -1.67324245e-02, -6.24206811e-02,
        -2.20701784e-01],
       [ 4.04927693e-03, -1.13871880e-01,  2.05220401e-01,
         1.32131383e-01, -7.64979348e-02,  6.49123862e-02,
        -3.68607938e-02, -1.69878453e-02,  1.45333305e-01,
        -2.73581028e-01],
       [-1.46519780e-01,  2.89793499e-02,  1.75667435e-01,
        -5.56092672e-02, -9.39915031e-02,  5.78897074e-02,
         1.03417955e-01, -6.95530884e-03,  1.08602785e-01,
        -3.11613858e-01],
       [ 3.49719077e-04, -4.73171473e-02,  2.04560548e-01,
         1.23362206e-02, -8.28021914e-02,  1.39382243e-01,
         6.15775213e-02, -7.69056529e-02, -3.72073725e-02,
        -1.85109019e-01],
       [ 4.59399708e-02, -1.45033717e-01,  1.53510854e-01,
         5.37299402e-02, -8.77758414e-02,  8.98126140e-02,
        -3.28433216e-02, -2.25196928e-02,  2.93625258e-02,
        -2.12388754e-01],
       [-9.48029086e-02, -1.18537471e-01,  8.13599378e-02,
        -1.08892776e-01, -2.44731586e-02,  5.60038723e-02,
         1.03811890e-01, -1.54534420e-02,  4.21965234e-02,
        -1.26404747e-01],
       [ 1.18354015e-01, -1.07880130e-01,  1.50448635e-01,
        -3.92492190e-02, -9.00383741e-02,  3.79986651e-02,
         8.14098269e-02,  2.75722146e-03,  3.57066542e-02,
        -3.12647343e-01],
       [-1.20544016e-01, -5.01988791e-02,  4.44822796e-02,
        -2.06980873e-02, -5.90143465e-02,  5.86909577e-02,
         1.56843930e-01, -4.91900481e-02,  5.57723641e-02,
        -1.87131226e-01],
       [-1.94707848e-02, -3.09111029e-02,  1.27355903e-01,
        -2.00778879e-02, -7.52782822e-02,  1.76905580e-02,
         5.35737388e-02, -7.19204172e-02, -7.88107663e-02,
        -2.73710012e-01],
       [-7.66897425e-02, -1.10425018e-02,  8.37464631e-02,
        -5.07526100e-04, -1.74719512e-01,  1.09259151e-02,
         8.31328332e-02, -6.13105968e-02, -2.88544334e-02,
        -3.19781363e-01],
       [ 4.15514745e-02, -1.79399341e-01,  1.29595235e-01,
        -7.51352310e-03, -1.04985982e-01, -2.32766867e-02,
         9.11241770e-02, -4.83584106e-02, -4.45627645e-02,
        -3.11877191e-01],
       [-1.22907385e-01, -8.35648179e-02, -4.61399555e-04,
         3.62833142e-02, -6.16921745e-02,  3.41793895e-02,
         6.09514117e-03, -2.58133207e-02,  3.86073887e-02,
        -1.72697291e-01],
       [-1.47430956e-01, -1.70899138e-01, -2.54742727e-02,
        -1.93053037e-02, -9.34216827e-02,  8.63033384e-02,
         1.95844889e-01, -1.20219477e-02, -7.42562041e-02,
        -2.68633634e-01],
       [-1.89905185e-02, -2.92911306e-02,  1.06189132e-01,
        -2.14210935e-02, -1.22999787e-01,  8.50224569e-02,
         9.03043598e-02, -1.00861326e-01,  7.86628351e-02,
        -1.96828574e-01],
       [-1.53984241e-02, -1.78240817e-02,  9.94455144e-02,
        -8.07050020e-02, -1.08934626e-01, -5.55047393e-03,
         3.86882052e-02,  5.02423719e-02, -2.03366280e-02,
        -1.51539668e-01],
       [ 9.89805907e-03, -1.00415491e-01,  3.55034679e-01,
        -7.64561146e-02,  1.09619442e-02,  7.78562427e-02,
        -2.41442919e-02, -7.27682002e-03,  1.75281227e-01,
        -1.17960580e-01],
       [-1.99002251e-01, -1.73170697e-02,  9.90331173e-03,
         5.70956320e-02, -1.32634252e-01,  6.45322800e-02,
         6.06766045e-02, -3.72909233e-02,  1.13204762e-01,
        -3.23195934e-01],
       [-8.26755986e-02, -1.36185214e-02,  9.65979844e-02,
        -9.37095098e-03, -9.52767730e-02,  1.47174925e-01,
         5.02739623e-02, -8.05822015e-03,  4.75806966e-02,
        -1.34617299e-01],
       [-7.80887008e-02, -5.81236407e-02, -5.40275350e-02,
        -7.83395171e-02, -3.40201519e-02,  1.07036822e-01,
         1.53975934e-01, -7.88928047e-02,  6.83138445e-02,
        -2.24947393e-01],
       [-1.68145984e-01,  1.80413537e-02,  7.25683197e-02,
        -8.54006410e-02, -5.02417013e-02,  8.27518255e-02,
         1.25536025e-02, -1.24789141e-02,  1.19018570e-01,
        -2.18624309e-01],
       [-1.93841606e-01,  2.02290900e-02,  2.84137726e-02,
         1.44029409e-03, -1.27114952e-02,  1.46217287e-01,
         1.24608524e-01,  8.01332593e-02,  2.23830119e-02,
        -2.26353824e-01],
       [-1.14570610e-01,  1.37106329e-03,  1.42955720e-01,
        -3.47122401e-02, -6.64473474e-02,  8.64062235e-02,
         1.44125596e-02,  1.29322475e-02,  9.87597778e-02,
        -2.46266648e-01],
       [ 1.29366428e-01, -9.02915299e-02,  8.26685205e-02,
        -1.06159896e-02, -4.45675664e-02, -4.79315333e-02,
         5.16594872e-02, -9.11353454e-02,  2.74216384e-02,
        -2.37027109e-01],
       [-3.45051289e-02,  8.24849755e-02,  6.68868348e-02,
        -5.46753630e-02, -9.83701795e-02,  9.21582431e-02,
         7.03592375e-02, -4.50211540e-02, -6.24448322e-02,
        -1.69667661e-01],
       [-1.85233742e-01, -1.67721972e-01, -5.34474105e-02,
        -4.94348034e-02, -5.77729978e-02,  5.78697175e-02,
         4.81551588e-02, -1.85359903e-02,  1.03780642e-01,
        -1.09119669e-01],
       [ 2.79906467e-02, -1.25610203e-01,  1.95233941e-01,
        -6.24950975e-04, -1.49155289e-01, -6.30461425e-02,
         3.01934332e-02,  4.90641892e-02, -2.38361508e-02,
        -1.47956103e-01],
       [-2.29758769e-02, -1.20755121e-01,  7.72470385e-02,
         6.54768348e-02, -9.47252512e-02,  7.46772215e-02,
         8.32523108e-02, -2.99419723e-02, -7.57840723e-02,
        -1.46706983e-01],
       [ 5.91882318e-03, -7.30389804e-02,  7.30036646e-02,
         3.24693620e-02, -1.04020640e-01, -3.04035842e-04,
         5.88735193e-02, -1.92306414e-02,  5.23711070e-02,
        -3.50282133e-01],
       [-1.20007187e-01, -7.61702210e-02,  9.75774527e-02,
         4.98527810e-02,  1.01156048e-02,  1.00960128e-01,
        -1.75609887e-02,  2.23743711e-02, -1.93186104e-03,
        -1.20586514e-01],
       [ 3.37485932e-02, -4.85759452e-02,  1.16713658e-01,
        -9.13519040e-03, -8.60348865e-02,  2.75263786e-02,
         4.08158600e-02,  2.31841542e-02, -9.30296183e-02,
        -1.24190323e-01],
       [-3.25062051e-02, -8.56647715e-02,  4.35702205e-02,
         5.17548621e-03, -7.83292949e-02, -3.01022921e-03,
         2.29148418e-02, -9.24165100e-02,  7.12836385e-02,
        -3.15343291e-01],
       [-5.13982773e-02,  4.44307514e-02,  1.43186092e-01,
        -4.81466763e-02, -4.43525612e-04,  1.55988052e-01,
         6.54639900e-02,  7.60897845e-02, -1.06668629e-01,
        -1.12153105e-01],
       [-3.10686305e-02, -9.24458653e-02,  2.01833948e-01,
         1.68880932e-02, -1.13820083e-01, -7.02955872e-02,
         3.07826363e-02,  4.15768698e-02,  4.62568998e-02,
        -1.73747614e-01],
       [ 1.47334374e-02, -4.67091314e-02,  9.16878283e-02,
         8.49749334e-03, -1.00595549e-01,  4.56146933e-02,
         4.77868319e-03,  1.34228319e-02, -8.27117935e-02,
        -1.85992092e-01],
       [-4.64355089e-02, -1.30060539e-01,  1.46310270e-01,
         9.11136493e-02, -9.22796726e-02,  3.81803960e-02,
         1.29324630e-01, -6.09851331e-02, -2.76935361e-02,
        -3.45277220e-01],
       [-1.84381716e-02,  7.08288103e-02,  5.04010245e-02,
        -3.02325673e-02, -1.33312596e-02,  9.46698114e-02,
        -5.11020906e-02,  3.12014055e-02, -2.78981961e-02,
        -1.42308325e-01],
       [-5.07387146e-02, -1.70747623e-01,  1.13214880e-01,
         3.89950573e-02, -8.52145851e-02,  1.81435719e-02,
         1.14593610e-01, -2.37021297e-02, -1.33675352e-01,
        -1.47563219e-01],
       [-1.11285999e-01, -4.64472324e-02,  1.05362996e-01,
         8.69985670e-03, -1.82327237e-02,  9.34701115e-02,
         1.04373902e-01, -4.88889180e-02,  9.24790651e-02,
        -2.04346865e-01],
       [ 2.68447474e-02, -6.55717775e-03,  9.68200266e-02,
        -9.11999028e-03, -5.37653975e-02,  6.85019046e-02,
         1.95344985e-02, -6.28999807e-03, -4.93589118e-02,
        -3.02899867e-01],
       [-8.84965062e-03, -8.10101628e-02,  1.57090068e-01,
        -5.16995788e-04, -7.99443573e-04,  1.61357388e-01,
         1.98418960e-01, -5.81040978e-02, -4.12518829e-02,
        -2.11800396e-01],
       [-1.61314569e-02, -7.50841498e-02,  6.72838911e-02,
        -1.77042354e-02, -7.52817765e-02,  9.92488489e-02,
         4.45844084e-02, -1.35439575e-01,  1.27033405e-02,
        -2.12368548e-01]], dtype=float32)>,
  3: <tf.Tensor: shape=(64, 10), dtype=float32, numpy=
array([[-2.64669694e-02, -3.38262543e-02,  1.42635211e-01,
         1.82953551e-02, -6.10583946e-02,  7.38784075e-02,
         1.63592696e-01, -4.28866819e-02,  7.73158297e-03,
        -2.86082923e-01],
       [-2.39099804e-02, -6.75053671e-02,  8.67765918e-02,
        -5.18340617e-03, -4.08537760e-02,  9.38599929e-02,
         8.11848789e-02, -8.55603665e-02, -4.94918078e-02,
        -2.03883007e-01],
       [-1.65042691e-02, -8.53570178e-02,  3.14284228e-02,
        -4.69481945e-02, -8.70868415e-02, -4.40990850e-02,
         5.28762639e-02, -1.48003884e-02, -3.02395672e-02,
        -1.61814839e-01],
       [-6.31767660e-02, -7.45066851e-02,  7.32449219e-02,
         4.72594053e-02, -8.39884877e-02,  9.71324928e-03,
         5.33771329e-02, -4.27164137e-02, -5.03900573e-02,
        -1.95417494e-01],
       [-1.29458681e-02, -1.00372039e-01,  1.38672113e-01,
        -4.69614267e-02, -1.36119291e-01,  4.96306829e-02,
         7.46882409e-02,  1.35311186e-02, -6.82607144e-02,
        -2.42756695e-01],
       [-8.25050473e-02, -1.18047968e-01,  1.00704893e-01,
         4.08207513e-02, -9.70342532e-02,  1.44894235e-02,
         4.92111444e-02,  4.06514183e-02, -5.72888181e-03,
        -1.23309992e-01],
       [ 5.86282983e-02, -2.69978307e-02,  1.71814293e-01,
         3.68928164e-03, -6.51650578e-02,  1.17143244e-03,
        -2.73939073e-02,  8.62740725e-02,  1.19505702e-02,
        -1.23816878e-01],
       [-1.30836681e-01,  3.53544727e-02,  9.63566601e-02,
        -6.18131086e-03, -5.96602857e-02,  1.32048532e-01,
        -2.84749344e-02,  1.05161376e-01,  1.62198126e-01,
        -3.30515325e-01],
       [-1.66418150e-01,  6.26565441e-02,  6.00432120e-02,
         4.81056049e-03, -3.37211862e-02,  1.86545014e-01,
         5.09998798e-02, -4.90842164e-02,  9.29402038e-02,
        -1.85279399e-01],
       [-1.55449226e-01, -3.42682600e-02, -1.55854866e-01,
        -1.00282878e-01,  5.85962161e-02,  1.07017875e-01,
        -2.67901272e-03, -6.94379359e-02, -5.64534515e-02,
        -1.63246885e-01],
       [-1.42344698e-01, -7.94031918e-02, -8.49062502e-02,
        -1.67585760e-02, -5.26322424e-02,  1.15996465e-01,
         1.77681834e-01, -4.55200300e-02,  5.84908314e-02,
        -2.40311787e-01],
       [-1.07000284e-02, -1.73099279e-01,  1.12982728e-01,
         1.24436900e-01, -1.09485939e-01,  3.87789421e-02,
         1.20341927e-02, -1.15824141e-01, -7.10867569e-02,
        -1.89201519e-01],
       [ 1.05197523e-02, -1.55473240e-02,  2.17849568e-01,
        -4.39790227e-02, -7.89612085e-02, -1.54420212e-02,
         1.34766832e-01, -5.79421744e-02,  7.30541497e-02,
        -3.16277683e-01],
       [ 4.93944436e-03, -8.88994709e-02,  1.27059951e-01,
        -4.78387177e-02, -1.09570041e-01,  7.09091723e-02,
         1.66771024e-01, -5.60563169e-02, -7.21291900e-02,
        -1.62604123e-01],
       [-7.66397566e-02, -1.27806276e-01,  1.68210953e-01,
         5.10694049e-02, -2.01880224e-02,  1.95342675e-02,
         5.61096258e-02, -9.97482799e-03, -6.18980825e-03,
        -1.81726933e-01],
       [-7.79834092e-02, -1.11465484e-01,  9.39368159e-02,
        -1.04102325e-02,  4.47943434e-03,  1.11190528e-01,
         5.33266366e-02, -5.19691817e-02, -4.69358973e-02,
        -5.68609275e-02],
       [-7.40044191e-03, -6.68793321e-02,  6.43530115e-02,
        -6.44794852e-02, -9.86383930e-02, -4.03227955e-02,
        -2.03765109e-02,  2.55719647e-02,  9.47349072e-02,
        -2.11710915e-01],
       [-4.76778708e-02, -8.36147964e-02, -5.42806797e-02,
        -1.48328990e-02,  7.09218346e-03,  3.76575217e-02,
         6.80175126e-02, -5.39340749e-02, -1.11827053e-01,
        -1.17766850e-01],
       [-7.08758458e-02, -1.50145330e-02,  1.37906641e-01,
        -1.09956324e-01, -8.05506408e-02,  5.87261096e-02,
         7.95594156e-02, -7.02388883e-02,  1.12074450e-01,
        -2.86128819e-01],
       [-1.41564205e-01, -1.84379086e-01,  6.08199090e-03,
        -3.57749760e-02, -1.20927207e-01,  5.45358323e-02,
         2.04979360e-01, -7.04201683e-03,  6.30752146e-02,
        -1.99726343e-01],
       [-2.46486142e-02, -1.43060118e-01,  2.44112194e-01,
         5.22014759e-02, -1.63198858e-01, -8.70878398e-02,
         4.50719818e-02,  3.69304605e-02,  2.86881700e-02,
        -1.97954357e-01],
       [-1.70564368e-01, -1.41653374e-01,  6.92816749e-02,
         3.37080359e-02, -1.05360776e-01,  1.05874792e-01,
         1.87227845e-01, -2.84348801e-03,  5.33786006e-02,
        -2.58223802e-01],
       [-3.00702229e-02, -7.21753836e-02,  1.05540782e-01,
         2.79989503e-02, -1.04631349e-01,  1.10977940e-01,
         8.80912393e-02, -4.90320697e-02,  9.53337848e-02,
        -2.87015498e-01],
       [-2.88813300e-02, -7.96605721e-02,  7.92133957e-02,
         1.69448499e-02,  9.08773020e-03,  1.21535525e-01,
         5.74936569e-02, -6.12248927e-02, -3.65313962e-02,
        -8.45414698e-02],
       [ 6.32017702e-02, -1.16388880e-01,  9.43448991e-02,
        -2.45954767e-02, -1.57698035e-01,  7.75593892e-03,
         7.38503188e-02, -4.63453233e-02,  2.46493015e-02,
        -3.57226938e-01],
       [ 2.41166055e-02, -2.00041413e-01,  1.61876678e-01,
         1.07993588e-01, -1.53217360e-01, -4.54890057e-02,
         4.32299078e-02, -7.34345019e-02, -2.73385979e-02,
        -1.64955348e-01],
       [ 4.98878583e-03, -7.03433901e-02,  1.72625899e-01,
         2.18960494e-02, -1.55882224e-01,  2.26828977e-02,
        -1.05383545e-02, -5.19299880e-02,  1.22356519e-01,
        -1.94377258e-01],
       [-1.44050986e-01, -8.05627480e-02, -8.46281648e-04,
         9.59945098e-03, -8.92172754e-03,  8.86529684e-02,
        -3.30525637e-03, -2.48262249e-02,  7.21365660e-02,
        -2.42367283e-01],
       [-9.61969569e-02,  6.51044399e-02, -1.67442858e-03,
        -4.05182652e-02, -2.98792124e-03,  1.06640391e-01,
         1.31437168e-01, -2.45141536e-02,  9.11239013e-02,
        -1.44164890e-01],
       [-1.05292745e-01,  4.29363176e-02,  3.14764753e-02,
        -1.68635845e-02,  7.07060099e-06,  1.50872976e-01,
         3.06056589e-02,  1.04895234e-01,  4.37264144e-03,
        -2.04268932e-01],
       [-1.37525111e-01, -1.27809107e-01,  1.03715301e-01,
         9.83670950e-02, -4.53438535e-02,  8.12293440e-02,
         1.55392528e-01, -6.62733391e-02, -1.28058463e-01,
        -1.71628654e-01],
       [-1.59386769e-01, -1.52590901e-01, -3.77684981e-02,
        -7.56763108e-03, -7.32103288e-02, -1.75090730e-02,
        -5.36651984e-02,  6.40557036e-02,  2.07698703e-01,
        -2.56075978e-01],
       [-8.20378587e-02, -6.43030629e-02, -2.37012058e-02,
        -7.71675482e-02, -7.12690204e-02,  7.18678385e-02,
         2.22103894e-01, -3.53842340e-02, -1.32079795e-02,
        -2.27164745e-01],
       [-1.88708812e-01, -3.74953412e-02, -6.17561266e-02,
        -4.05938029e-02, -1.33953393e-01,  4.83541340e-02,
         1.07024759e-01,  1.64985433e-02,  1.08813286e-01,
        -2.53184080e-01],
       [ 2.46044882e-02,  3.34643200e-03,  1.00584090e-01,
        -1.31709933e-01, -1.03203796e-01, -2.46407501e-02,
         6.59326911e-02,  3.55283618e-02,  1.53171986e-01,
        -3.35660040e-01],
       [-4.86803763e-02, -8.44964981e-02,  3.50355767e-02,
        -6.84254616e-02, -7.94377029e-02,  1.58106908e-02,
         1.32192954e-01, -6.56674504e-02, -1.68727741e-01,
        -2.00229853e-01],
       [-9.25698318e-03, -1.82153299e-01,  1.12911947e-01,
         7.01041594e-02, -1.06265306e-01, -4.13442552e-02,
        -7.88840652e-03, -2.72199512e-02, -1.30737871e-02,
        -1.64280236e-01],
       [ 6.81568161e-02, -2.14689732e-01,  6.27778843e-02,
         3.86403836e-02, -6.67396709e-02,  4.73842025e-04,
        -1.25432499e-02, -5.98382652e-02,  7.93668032e-02,
        -3.01945388e-01],
       [ 1.69950798e-02, -1.49343759e-01,  1.32372409e-01,
         3.11933160e-02, -1.44183084e-01, -3.27975452e-02,
         5.43530583e-02, -5.53915016e-02,  8.78268927e-02,
        -2.47543320e-01],
       [-1.20600574e-01, -7.36613646e-02,  8.44000056e-02,
        -8.12885314e-02, -1.10457651e-02,  1.10907272e-01,
         1.97365969e-01, -7.86562338e-02,  1.86791271e-01,
        -1.96342722e-01],
       [-3.85482386e-02,  5.80951124e-02,  1.16936497e-01,
        -9.17282552e-02, -5.24022356e-02,  6.13490269e-02,
        -2.44482681e-02,  4.97046672e-02, -1.23902299e-02,
        -2.36063093e-01],
       [ 3.50311697e-02, -6.97146282e-02,  1.00809224e-01,
        -2.48865560e-02, -8.00992921e-02,  1.11433715e-01,
         1.76830754e-01, -9.19253752e-02, -7.60221034e-02,
        -1.55015558e-01],
       [-1.84618682e-02, -2.26229057e-03,  3.76343206e-02,
        -6.90721124e-02, -4.00315486e-02,  7.73715228e-03,
         1.06939673e-01, -4.81865034e-02, -1.28759220e-01,
        -1.55072302e-01],
       [-5.23045063e-02, -1.95138901e-03,  1.31059527e-01,
        -1.24340586e-01, -1.31771639e-02,  1.15858883e-01,
         1.09098136e-01, -6.17204793e-02,  1.20730974e-01,
        -1.88713402e-01],
       [-1.82524756e-01,  3.01376507e-02,  4.21366021e-02,
         1.07705340e-01,  4.07074578e-03,  1.01444110e-01,
         9.33211148e-02, -6.50132671e-02,  1.55968666e-02,
        -2.17603400e-01],
       [-7.34460205e-02, -6.40550628e-02,  1.62612572e-01,
         2.07998231e-03, -1.04698032e-01, -2.52882317e-02,
         1.51614502e-01, -3.01348716e-02, -4.20465022e-02,
        -2.32530192e-01],
       [ 4.32404280e-02, -1.79166093e-01,  1.29130483e-01,
         5.29525988e-02, -4.59525809e-02,  9.48042348e-02,
         9.91417542e-02, -1.02751359e-01, -9.99256819e-02,
        -2.80976355e-01],
       [-1.52932853e-01, -6.58576041e-02,  3.34841684e-02,
        -6.90476596e-02, -8.95142406e-02,  3.76654044e-02,
         1.08571425e-01, -4.89762872e-02,  1.22642308e-01,
        -2.20513642e-01],
       [-1.35309398e-01, -1.12868197e-01, -1.85489237e-01,
         3.84217203e-02, -7.46929795e-02,  3.75274494e-02,
         7.52522573e-02, -1.16239741e-01, -8.69273543e-02,
        -2.17446089e-01],
       [-7.84089938e-02, -1.00249879e-01,  1.32972941e-01,
         7.98643231e-02, -4.04328853e-02,  1.00631088e-01,
         1.49289951e-01, -9.74092036e-02, -3.37632671e-02,
        -3.46492946e-01],
       [-1.76333323e-01, -2.21099615e-01,  9.93300974e-03,
        -4.28270176e-02, -8.48801732e-02,  6.27638400e-02,
         7.00467676e-02,  2.16519833e-03,  1.58375859e-01,
        -1.55701905e-01],
       [-3.81720811e-02,  5.62158227e-02,  4.48521227e-02,
        -4.67639454e-02, -3.47419642e-03,  8.52869600e-02,
         4.70591225e-02,  6.43402040e-02, -3.82010601e-02,
        -1.32496700e-01],
       [-1.07183166e-01, -3.57413329e-02,  1.24337286e-01,
        -8.86357278e-02, -1.20967336e-01, -8.44574161e-03,
         2.80514359e-05, -1.20141786e-02,  6.42660409e-02,
        -2.48643801e-01],
       [ 1.83491241e-02, -6.70650229e-03,  1.63027897e-01,
        -3.07689309e-02, -4.95545827e-02,  1.01515427e-02,
        -3.93696502e-03,  5.27074262e-02, -3.08710895e-03,
        -1.06414281e-01],
       [-5.78492172e-02, -1.77028105e-02,  8.53653997e-02,
         4.94639501e-02, -6.67294562e-02,  1.26378000e-01,
         1.47325665e-01, -3.67106423e-02, -9.10526589e-02,
        -2.02979073e-01],
       [-5.75959831e-02, -1.59533769e-01,  5.35644963e-03,
         8.53684619e-02, -7.71569237e-02,  9.58867073e-02,
         1.82126731e-01, -8.76367614e-02, -1.64111912e-01,
        -2.36350983e-01],
       [ 2.87542194e-02,  4.11205292e-02,  1.54037550e-02,
        -7.39890784e-02, -3.99961285e-02,  5.83926402e-02,
        -1.59002841e-03,  5.08946180e-03,  5.76792881e-02,
        -2.33185023e-01],
       [-1.32336572e-01, -3.18910778e-02,  5.26991636e-02,
         8.75673629e-03, -2.75723618e-02,  1.05614796e-01,
         1.57546744e-01,  9.33468528e-03, -1.27191246e-01,
        -1.26512900e-01],
       [-1.03891142e-01, -5.44731095e-02,  7.06407055e-03,
        -1.32387295e-01, -4.07544859e-02,  1.00440048e-01,
         1.29370183e-01, -8.29259232e-02,  4.68877256e-02,
        -1.68420091e-01],
       [-3.46207023e-02, -5.62616065e-02,  1.02029607e-01,
        -9.13117379e-02,  9.51373428e-02,  1.73789680e-01,
         1.07564330e-02,  1.47677809e-02, -2.23123394e-02,
        -1.27627820e-01],
       [ 3.64530645e-02, -2.00009674e-01,  7.16215372e-02,
         5.68866543e-02, -6.85976148e-02,  4.96518128e-02,
        -2.75990199e-02, -2.76161209e-02, -5.71828224e-02,
        -1.55968279e-01],
       [-6.21210262e-02, -4.52079251e-02,  1.18788272e-01,
        -2.32128687e-02, -1.40547067e-01,  3.64853740e-02,
         1.05344623e-01, -2.46146731e-02, -1.09788045e-01,
        -1.88733384e-01],
       [-4.27781120e-02, -8.95003751e-02,  1.59641579e-01,
        -7.70486332e-03, -6.55359328e-02,  8.60365629e-02,
         1.26972497e-02, -2.87206247e-02, -7.65442178e-02,
        -1.94938064e-01],
       [-1.04442030e-01,  1.78434327e-03,  9.27254930e-02,
        -5.01224697e-02, -8.29514861e-03,  1.24611050e-01,
         3.08808610e-02, -2.22852137e-02,  3.16940621e-02,
        -2.74606556e-01]], dtype=float32)>
} }
2022-10-20 04:29:11.728997: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Calling the restored function performs only a forward pass on the saved model (like tf.keras.Model.predict). What if you want to continue training the loaded function, or embed it in a bigger model? A common practice is to wrap the loaded object in a Keras layer. TF Hub provides hub.KerasLayer for this purpose, as shown here:

import tensorflow_hub as hub

def build_model(loaded):
  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
  # Wrap the loaded SavedModel object in a KerasLayer so that it can be
  # trained further and composed with other Keras layers
  outputs = hub.KerasLayer(loaded, trainable=True)(x)
  model = tf.keras.Model(x, outputs)
  return model

another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  # Load inside the strategy scope so the restored variables are distributed
  loaded = tf.saved_model.load(saved_model_path)
  model = build_model(loaded)

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])
  model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11.
2022-10-20 04:29:12.327410: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:549] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:batch_all_reduce: 6 all-reduces with algorithm = nccl, num_packs = 1
235/235 [==============================] - 7s 10ms/step - loss: 0.3219 - sparse_categorical_accuracy: 0.9119
Epoch 2/2
235/235 [==============================] - 2s 10ms/step - loss: 0.1002 - sparse_categorical_accuracy: 0.9717

In the above example, TensorFlow Hub's hub.KerasLayer wraps the object loaded back with tf.saved_model.load into a Keras layer that is used to build another model. This is very useful for transfer learning.

Which API should I use?

For saving, if you are working with a Keras model, use the Keras Model.save API unless you need the additional control allowed by the low-level API. If what you are saving is not a Keras model, then the lower-level API, tf.saved_model.save, is your only choice.

For loading, your API choice depends on what you want to get from the model loading API. If you cannot (or do not want to) get a Keras model, then use tf.saved_model.load. Otherwise, use tf.keras.models.load_model. Note that you can get a Keras model back only if you saved a Keras model.
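As a minimal sketch of the low-level path for a non-Keras object, the snippet below saves a small tf.Module with tf.saved_model.save and restores it with tf.saved_model.load. The Scaler class, its factor variable, and the temporary export directory are hypothetical names chosen for illustration. Note that the restored object is a generic trackable object, not a Keras model.

```python
import tempfile
import tensorflow as tf

class Scaler(tf.Module):
  """Hypothetical non-Keras module that multiplies its input by a variable."""

  def __init__(self):
    super().__init__()
    self.factor = tf.Variable(3.0)

  # The input_signature lets tf.saved_model.save trace a ConcreteFunction.
  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def __call__(self, x):
    return x * self.factor

export_dir = tempfile.mkdtemp()  # hypothetical local export path

# Low-level save: the only option for an object that is not a Keras model.
tf.saved_model.save(Scaler(), export_dir)

# Low-level load: returns a generic trackable object with the saved function.
restored = tf.saved_model.load(export_dir)
print(restored(tf.constant([1.0, 2.0])).numpy())  # [3. 6.]
print(restored.factor.numpy())                    # 3.0
```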

It is possible to mix and match the APIs. For example, you can save a Keras model with Model.save and load it as a non-Keras object with the low-level API, tf.saved_model.load.

model = get_model()

# Saving the model using Keras `Model.save`
model.save(keras_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
  loaded = tf.saved_model.load(keras_model_path)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _update_step_xla while saving (showing 2 of 2). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')

Saving/Loading from a local device

When saving to and loading from a local I/O device while training on remote devices (for example, when using a Cloud TPU), you must use the experimental_io_device option in tf.saved_model.SaveOptions and tf.saved_model.LoadOptions to set the I/O device to localhost. For example:

model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _update_step_xla while saving (showing 2 of 2). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
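The same options can also be passed to the low-level API, because tf.saved_model.save and tf.saved_model.load both accept an options argument. The sketch below assumes a hypothetical, variable-free AddOne module and a temporary export directory. On a single machine, '/job:localhost' simply refers to the local host, so this only shows where the options go; it does not reproduce a Cloud TPU setup.

```python
import tempfile
import tensorflow as tf

class AddOne(tf.Module):
  """Hypothetical variable-free module used only to illustrate the options."""

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def __call__(self, x):
    return x + 1.0

export_dir = tempfile.mkdtemp()  # hypothetical local export path

# Route the SavedModel file writes through the local host.
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
tf.saved_model.save(AddOne(), export_dir, options=save_options)

strategy = tf.distribute.MirroredStrategy()  # falls back to CPU if no GPU is present
with strategy.scope():
  # Route the SavedModel file reads through the local host as well.
  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
  restored = tf.saved_model.load(export_dir, options=load_options)

print(restored(tf.constant([1.0, 2.0])).numpy())  # [2. 3.]
```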

Caveats

One special case is when you create a Keras model without specifying its input shape (for example, by subclassing tf.keras.Model) and then save it before training. For example:

class SubclassedModel(tf.keras.Model):
  """Example model defined by subclassing `tf.keras.Model`."""

  output_name = 'output_layer'

  def __init__(self):
    super(SubclassedModel, self).__init__()
    self._dense_layer = tf.keras.layers.Dense(
        5, dtype=tf.dtypes.float32, name=self.output_name)

  def call(self, inputs):
    return self._dense_layer(inputs)

my_model = SubclassedModel()
try:
  my_model.save(keras_model_path)
except ValueError as e:
  print(f'{type(e).__name__}: ', *e.args)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7fa3ec39b190>, because it is not built.
ValueError:  Model <__main__.SubclassedModel object at 0x7fa3ec39b190> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.

A SavedModel saves the tf.types.experimental.ConcreteFunction objects generated when you trace a tf.function (check When is a Function tracing? in the Introduction to graphs and tf.function guide to learn more). If you get a ValueError like this, it's because Model.save was not able to find or create a traced ConcreteFunction. The low-level tf.saved_model.save does not raise an error in this case, but the resulting SavedModel has no signatures:

tf.saved_model.save(my_model, saved_model_path)
x = tf.saved_model.load(saved_model_path)
x.signatures
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7fa3ec39b190>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7fa3442cc4c0>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
_SignatureMap({})
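For comparison, here is a minimal sketch at the tf.function level, assuming a hypothetical Doubler module. Because its tf.function has no input_signature, nothing is traced until you either call it or request a ConcreteFunction explicitly with get_concrete_function. Passing that ConcreteFunction as signatures registers it under the default serving key, so the loaded object's signatures map is no longer empty.

```python
import tempfile
import tensorflow as tf

class Doubler(tf.Module):
  """Hypothetical module whose tf.function has no input_signature."""

  @tf.function
  def __call__(self, x):
    return 2.0 * x

module = Doubler()

# Trace the forward pass explicitly for a 1-D float32 input of any length.
concrete_fn = module.__call__.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.float32, name='x'))

export_dir = tempfile.mkdtemp()  # hypothetical local export path
# A single ConcreteFunction is exported under the default serving signature key.
tf.saved_model.save(module, export_dir, signatures=concrete_fn)

restored = tf.saved_model.load(export_dir)
print(list(restored.signatures.keys()))  # ['serving_default']

# Signature functions take keyword arguments and return a dict of outputs.
serving_fn = restored.signatures['serving_default']
outputs = serving_fn(x=tf.constant([1.0, 2.0]))
print(outputs['output_0'].numpy())  # [2. 4.]
```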

Usually, the model's forward pass (the call method) is traced automatically when the model is called for the first time, often via the Keras Model.fit method. A ConcreteFunction can also be generated by the Keras Sequential and Functional APIs if you set the input shape: for example, make the first layer a tf.keras.layers.InputLayer, or use another layer type and pass it the input_shape keyword argument.
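The difference can be checked with a quick sketch, assuming a hypothetical TinyModel class. It inspects the Keras built attribute, which reports whether the model's weights have been created. This is a convenient proxy for "the forward pass has run", not a direct query of the traced ConcreteFunctions, and it creates fresh models so that my_model stays untouched.

```python
import tensorflow as tf

class TinyModel(tf.keras.Model):
  """Hypothetical subclassed model: its input shape is unknown until it is called."""

  def __init__(self):
    super().__init__()
    self.dense = tf.keras.layers.Dense(5)

  def call(self, inputs):
    return self.dense(inputs)

subclassed = TinyModel()
print(subclassed.built)  # False: no forward pass has run yet

# Calling the model on a batch of data runs the forward pass and builds it.
_ = subclassed(tf.zeros((1, 5)))
print(subclassed.built)  # True

# A Sequential model that declares its input shape up front is built immediately.
sequential = tf.keras.Sequential([
    tf.keras.Input(shape=(5,)),
    tf.keras.layers.Dense(5),
])
print(sequential.built)  # True
```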

To verify whether your model has any traced ConcreteFunctions, check whether Model.save_spec() returns None:

print(my_model.save_spec() is None)
True

Now use tf.keras.Model.fit to train the model. Notice that save_spec gets defined and model saving works:

BATCH_SIZE_PER_REPLICA = 4
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
    ).repeat(dataset_size).batch(BATCH_SIZE)

my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)

print(my_model.save_spec() is None)
my_model.save(keras_model_path)
Epoch 1/2
7/7 [==============================] - 1s 3ms/step - loss: 13.5012
Epoch 2/2
7/7 [==============================] - 0s 2ms/step - loss: 13.0349
False
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa34411deb0>, 140342019215248), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fa34409d670>, 140339738260944), {}).
WARNING:absl:Found untraced functions such as _update_step_xla while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/keras_save/assets
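The 7/7 progress bar above follows from the batch arithmetic. The short calculation below assumes four replicas, matching the four GPUs reported by MirroredStrategy earlier in this tutorial. On a different machine, mirrored_strategy.num_replicas_in_sync, and therefore the step count, will differ.

```python
import math

num_replicas = 4            # assumption: the four GPUs listed in the logs above
batch_size_per_replica = 4  # BATCH_SIZE_PER_REPLICA in the cell above
dataset_size = 100          # dataset_size in the cell above

# Global batch size that Dataset.batch() receives.
global_batch_size = batch_size_per_replica * num_replicas

# batch() keeps the final partial batch by default, so round up.
steps_per_epoch = math.ceil(dataset_size / global_batch_size)

# Examples left over for the final, smaller batch.
last_batch_size = dataset_size - (steps_per_epoch - 1) * global_batch_size

print(global_batch_size, steps_per_epoch, last_batch_size)  # 16 7 4
```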