Overview
This tutorial demonstrates how you can save and load models in a SavedModel format with tf.distribute.Strategy during or after training. There are two kinds of APIs for saving and loading a Keras model: high-level (tf.keras.Model.save and tf.keras.models.load_model) and low-level (tf.saved_model.save and tf.saved_model.load).
To learn about SavedModel and serialization in general, please read the SavedModel guide and the Keras model serialization guide. Let's start with a simple example.
Import dependencies:
import tensorflow_datasets as tfds
import tensorflow as tf
Load and prepare the data with TensorFlow Datasets and tf.data, and create the model using tf.distribute.MirroredStrategy:
mirrored_strategy = tf.distribute.MirroredStrategy()

def get_data():
  datasets = tfds.load(name='mnist', as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000

  # The global batch size is split evenly across the replicas in sync.
  BATCH_SIZE_PER_REPLICA = 64
  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync

  # Scale pixel values from [0, 255] to [0, 1].
  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label

  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

  return train_dataset, eval_dataset

def get_model():
  # Build and compile inside the strategy scope so the variables are mirrored.
  with mirrored_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
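The log above shows four GPU replicas, so the global batch size in get_data is 64 × 4 = 256. The minimal sketch below reproduces that arithmetic in plain Python; the helpers global_batch_size and steps_per_epoch are hypothetical illustrations, not TensorFlow APIs. It also suggests why each training epoch below reports 235 steps.

```python
import math


def global_batch_size(per_replica: int, num_replicas: int) -> int:
    """Mirror BATCH_SIZE = BATCH_SIZE_PER_REPLICA * num_replicas_in_sync."""
    return per_replica * num_replicas


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Batches produced by dataset.batch(batch_size) without drop_remainder."""
    return math.ceil(num_examples / batch_size)


batch = global_batch_size(64, 4)        # four GPUs, as in the log above
steps = steps_per_epoch(60_000, batch)  # MNIST has 60,000 training images
print(batch, steps)  # 256 235
```

On a single-device machine, num_replicas_in_sync is 1 and the global batch size falls back to 64.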
Train the model with tf.keras.Model.fit:
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
Epoch 1/2
235/235 [==============================] - 9s 7ms/step - loss: 0.3359 - sparse_categorical_accuracy: 0.9084
Epoch 2/2
235/235 [==============================] - 2s 7ms/step - loss: 0.0979 - sparse_categorical_accuracy: 0.9723
<keras.src.callbacks.History at 0x7fb92c29e460>
Save and load the model
Now that you have a simple model to work with, let's explore the saving/loading APIs. There are two kinds of APIs available:
- High-level (Keras): Model.save and tf.keras.models.load_model (.keras zip archive format)
- Low-level: tf.saved_model.save and tf.saved_model.load (TF SavedModel format)
The Keras API
Here is an example of saving and loading a model with the Keras API:
keras_model_path = '/tmp/keras_save.keras'
model.save(keras_model_path)
Restore the model without tf.distribute.Strategy:
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2
235/235 [==============================] - 2s 3ms/step - loss: 0.0652 - sparse_categorical_accuracy: 0.9814
Epoch 2/2
235/235 [==============================] - 1s 3ms/step - loss: 0.0505 - sparse_categorical_accuracy: 0.9852
<keras.src.callbacks.History at 0x7fbae4011a30>
After restoring the model, you can continue training it without calling Model.compile again, since the model was already compiled before saving. The model is saved in the Keras zip archive format, marked by the .keras extension. For more information, please refer to the guide on Keras saving.
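To see the round trip in isolation, here is a minimal sketch on a tiny stand-in model (a hypothetical two-layer model, much smaller than the tutorial's CNN, with random data). It assumes TensorFlow 2.12 or later so that the .keras format is available. It checks that the weights survive saving and loading, and that the restored model can call fit without a new compile call.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical tiny model standing in for the tutorial's CNN.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3),
])
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam())

# Random stand-in data: 8 examples, 4 features, 3 classes.
x = np.random.rand(8, 4).astype('float32')
y = np.random.randint(0, 3, size=(8,))
model.fit(x, y, epochs=1, verbose=0)

# The .keras extension selects the Keras zip archive format.
path = os.path.join(tempfile.mkdtemp(), 'sketch.keras')
model.save(path)
restored = tf.keras.models.load_model(path)

# Same weights, so the predictions should match.
same = np.allclose(model.predict(x, verbose=0), restored.predict(x, verbose=0))

# The compile configuration is restored as well, so fit() works directly.
history = restored.fit(x, y, epochs=1, verbose=0)
```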
Now, restore the model and train it using a tf.distribute.Strategy:
another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
  restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2
235/235 [==============================] - 3s 12ms/step - loss: 0.0642 - sparse_categorical_accuracy: 0.9819
Epoch 2/2
235/235 [==============================] - 3s 12ms/step - loss: 0.0504 - sparse_categorical_accuracy: 0.9854
As the Model.fit output shows, loading works as expected with tf.distribute.Strategy. The strategy used here does not have to be the same strategy used before saving.
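The following sketch checks that claim on a hypothetical tiny model, assuming TensorFlow 2.12 or later. It saves a model whose variables were created under MirroredStrategy (on a machine without GPUs, MirroredStrategy mirrors onto the CPU), then loads it under OneDeviceStrategy and under no strategy at all, and compares predictions.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical stand-in model; any compiled Keras model would do.
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(2)])
    model.compile(loss='mse', optimizer='adam')
    return model


# Create the variables under MirroredStrategy, then save.
mirrored = tf.distribute.MirroredStrategy()
with mirrored.scope():
    model = build_model()
path = os.path.join(tempfile.mkdtemp(), 'any_strategy.keras')
model.save(path)

x = np.ones((3, 4), dtype='float32')
expected = model.predict(x, verbose=0)

# Load under a different strategy...
one_device = tf.distribute.OneDeviceStrategy('/cpu:0')
with one_device.scope():
    under_one_device = tf.keras.models.load_model(path)

# ...and under the default (no) strategy.
no_strategy = tf.keras.models.load_model(path)
```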
The tf.saved_model API
Saving the model with the lower-level API is similar to saving it with the Keras API:
model = get_model() # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
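Unlike the single .keras file, tf.saved_model.save writes a directory. As a rough sketch, using a hypothetical minimal tf.Module named Scale rather than the Keras model, you can list what ends up on disk: typically saved_model.pb (the serialized graphs and signatures) and a variables/ subdirectory holding the checkpointed weights. Depending on the TensorFlow version you may also see assets/ and fingerprint.pb.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical minimal module: multiplies its input by a stored factor."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.factor


export_dir = os.path.join(tempfile.mkdtemp(), 'tf_save_sketch')
tf.saved_model.save(Scale(), export_dir)

contents = sorted(os.listdir(export_dir))
variable_files = sorted(os.listdir(os.path.join(export_dir, 'variables')))
print(contents)
print(variable_files)
```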
Loading can be done with tf.saved_model.load. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:
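A minimal sketch of that behavior, again with a hypothetical Scale module instead of the Keras model: the object returned by tf.saved_model.load is a generic restored object rather than an instance of the original Python class, but its traced tf.function and its variables come back and remain usable.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical minimal module: multiplies its input by a stored factor."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.factor


export_dir = os.path.join(tempfile.mkdtemp(), 'scale')
tf.saved_model.save(Scale(), export_dir)
loaded = tf.saved_model.load(export_dir)

# The Python class itself is not restored...
is_scale = isinstance(loaded, Scale)
# ...but the traced function and the variable are.
result = loaded(tf.constant([1.0, 2.0]))
factor = float(loaded.factor.numpy())
```

The same reasoning applies to the Keras model above: tf.saved_model.load gives back an object with a signatures attribute, not a tf.keras.Model with fit and predict.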
DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
The loaded object may contain multiple functions, each associated with a key. For a saved Keras model, "serving_default" is the default key for the inference function. To do inference with this function:
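When you call tf.saved_model.save yourself, the signatures argument controls which functions are exported under which keys. In this sketch (hypothetical Scale module), passing a single concrete function stores it under "serving_default". The signature returns a dictionary; a function that returns a bare tensor gets the output name "output_0". The Keras signature in this tutorial instead names its output after the model's final layer, which is why the result printed below is keyed 'dense_3'.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical minimal module: multiplies its input by a stored factor."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.factor


module = Scale()
export_dir = os.path.join(tempfile.mkdtemp(), 'scale_signatures')
# A single concrete function is exported under the "serving_default" key.
tf.saved_model.save(module, export_dir,
                    signatures=module.__call__.get_concrete_function())

loaded = tf.saved_model.load(export_dir)
keys = list(loaded.signatures.keys())
infer = loaded.signatures['serving_default']

# Signature inputs are named after the Python arguments (here, x).
outputs = infer(x=tf.constant([1.0, 2.0]))
```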
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
  print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(256, 10), dtype=float32, numpy=
array([[ 0.0277961 , -0.25427482, 0.41693777, ..., -0.09557065, -0.15324523, 0.03819636],
       [-0.336356 , -0.41538522, 0.24396369, ..., -0.16313636, -0.07839881, 0.16986977],
       [-0.03182657, -0.33101833, 0.19745234, ..., -0.08597949, -0.11427765, 0.07309295],
       ...,
       [-0.1376599 , -0.23535399, 0.13248071, ..., -0.00206384, -0.04547129, 0.16397868],
       [-0.12199271, -0.26029754, 0.02518643, ..., -0.05039004, -0.07103124, 0.16076234],
       [-0.13545126, -0.2061567 , 0.19009882, ..., -0.08650086, -0.0093526 , 0.19425942]], dtype=float32)>}
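The printed values are raw logits: the final Dense(10) layer has no activation and the loss was built with from_logits=True. The sketch below uses random stand-in data and hypothetical logits (not the tutorial's MNIST tensors). It shows the label-dropping map used for predict_dataset, and one common way to turn logits into class ids with tf.argmax over the last axis.

```python
import numpy as np
import tensorflow as tf

# Stand-in for eval_dataset: (image, label) pairs, batched.
images = np.random.rand(6, 28, 28, 1).astype('float32')
labels = np.arange(6)
eval_like = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

# Same transformation as predict_dataset: keep the image, drop the label.
predict_like = eval_like.map(lambda image, label: image)
batch_sizes = [int(batch.shape[0]) for batch in predict_like]

# Hypothetical logits for two examples and three classes.
fake_logits = tf.constant([[0.1, 2.0, -1.0],
                           [3.0, 0.0, 0.5]])
predicted = tf.argmax(fake_logits, axis=-1)  # index of the largest logit per row
```

With the tutorial's objects, the analogous call would be tf.argmax(inference_func(batch)['dense_3'], axis=-1), assuming the output key matches the one printed above.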
You can also load and do inference in a distributed manner:
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
  loaded = tf.saved_model.load(saved_model_path)
  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]

  dist_predict_dataset = another_strategy.experimental_distribute_dataset(
      predict_dataset)

  # Calling the function in a distributed manner
  for batch in dist_predict_dataset:
    result = another_strategy.run(inference_func, args=(batch,))
    print(result)
    break
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
2023-07-27 06:34:13.062776: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:551] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
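The warning above recommends wrapping strategy.run in a tf.function. The sketch below follows that advice with a hypothetical Scale module saved with an explicit signature. It also shows one way to collect the per-replica results: strategy.experimental_local_results unpacks a PerReplica value into a tuple of tensors, which are concatenated here in replica order. On a machine without GPUs, MirroredStrategy runs a single CPU replica, so the code still runs.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical minimal module: multiplies its input by a stored factor."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(3.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.factor


module = Scale()
export_dir = os.path.join(tempfile.mkdtemp(), 'scale_dist')
tf.saved_model.save(module, export_dir,
                    signatures=module.__call__.get_concrete_function())

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    loaded = tf.saved_model.load(export_dir)
    infer = loaded.signatures['serving_default']

dataset = tf.data.Dataset.from_tensor_slices(tf.range(8, dtype=tf.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def distributed_infer(batch):
    # Runs `infer` once per replica inside a graph, avoiding eager overhead.
    return strategy.run(infer, kwargs={'x': batch})


collected = []
for dist_batch in dist_dataset:
    per_replica = distributed_infer(dist_batch)
    local = strategy.experimental_local_results(per_replica['output_0'])
    collected.append(tf.concat(local, axis=0))
result = tf.concat(collected, axis=0)
```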
{'dense_3': PerReplica:{ 0: <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[ 2.77961008e-02, -2.54274815e-01, 4.16937768e-01, -3.72880578e-01, -3.29998657e-02, -6.58259764e-02, 5.94227128e-02, -9.55706537e-02, -1.53245226e-01, 3.81963551e-02], [-3.36356014e-01, -4.15385216e-01, 2.43963689e-01, -3.62907380e-01, -2.37839982e-01, 1.35413557e-02, 9.63806435e-02, -1.63136363e-01, -7.83988088e-02, 1.69869766e-01], [-3.18265669e-02, -3.31018329e-01, 1.97452337e-01, -2.33169883e-01, -1.30116433e-01, 1.72409430e-01, 1.78383052e-01, -8.59794915e-02, -1.14277646e-01, 7.30929524e-02], [-1.65744871e-02, -2.28164136e-01, 2.77862370e-01, -1.85078382e-01, -7.64370710e-02, 9.82998610e-02, 1.01667866e-02, -6.43879622e-02, -8.44108313e-02, -1.03641301e-02], [ 4.22359444e-02, -1.69978052e-01, 1.71413556e-01, -1.52895629e-01, -1.03923768e-01, 3.96954454e-02, 1.12688929e-01, -7.81170577e-02, -7.41601288e-02, 1.25813678e-01], [-3.57872956e-02, -2.01575994e-01, 1.96702540e-01, -3.88146490e-01, -1.12675011e-01, -5.58326393e-03, 3.82668376e-02, 4.22903523e-03, 2.17693895e-02, 1.08423218e-01], [-2.07205385e-01, -2.59337306e-01, 2.60428250e-01, -1.46623790e-01, -1.16065934e-01, 4.58090231e-02, 8.33489001e-04, -1.12877734e-01, -9.29935053e-02, 2.04128772e-03], [-1.77496448e-02, -1.68957263e-01, 1.55010790e-01, -1.70428276e-01, -6.28218353e-02, -5.46750277e-02, 4.30600196e-02, -4.52936888e-02, -4.25418466e-03, 9.09864157e-02], [-5.83440810e-02, -2.79580832e-01, 2.30377674e-01, -2.67966032e-01, -2.23203301e-01, -9.21733454e-02, 6.82215244e-02, -7.92368799e-02, -1.16557732e-01, 2.29901373e-01], [-1.66451573e-01, -1.27965242e-01, 8.50739703e-02, -1.67961344e-01, -9.98345464e-02, -4.91920859e-04, 5.27660176e-02, 1.20072439e-02, 2.08993554e-02, -1.30219087e-02], [-1.13255329e-01, -2.44278759e-01, 1.25342607e-01, -2.20574856e-01, -9.59239006e-02, 6.24536797e-02, 1.35844558e-01, -7.39271641e-02, -8.24161768e-02, 9.56056267e-02], [-1.57433510e-01, -2.29775012e-01, 2.63743341e-01, 
-2.54589796e-01, -3.25095691e-02, 3.24455313e-02, 1.64987221e-02, -1.35588739e-02, 8.02063942e-03, 1.65905043e-01], [-5.54538928e-02, -1.62802041e-01, 5.56794330e-02, -1.22614920e-01, 1.26078837e-02, 6.13490418e-02, 1.01778530e-01, -3.70408259e-02, -1.35597169e-01, 1.09969214e-01], [ 3.03655379e-02, -2.09959269e-01, 2.23499626e-01, -2.70942420e-01, -4.39836495e-02, 4.81552780e-02, 6.53428882e-02, -3.33468914e-02, -1.19314417e-01, 1.35095775e-01], [-2.63083801e-02, -2.31794327e-01, 2.83206254e-01, -2.27834284e-01, -8.86126012e-02, 1.02761075e-01, -2.06644014e-02, -1.48879681e-02, -1.43723190e-02, 8.48568678e-02], [ 4.40791063e-02, -2.51471341e-01, 2.46300191e-01, -2.48126388e-01, -6.29781634e-02, 1.08190909e-01, 8.61894041e-02, -5.60469851e-02, -8.09555054e-02, 1.27782807e-01], [ 1.95840001e-02, -2.24844486e-01, 2.38408834e-01, -1.66516751e-01, -3.48671526e-02, -2.76660472e-02, 1.11107826e-01, -1.05957389e-01, -1.13373876e-01, 9.40269902e-02], [-2.61919200e-02, -1.80716693e-01, 2.11239651e-01, -1.48656636e-01, 5.61328717e-02, -5.17362207e-02, 9.15940180e-02, -1.01114754e-02, -1.08649991e-01, 1.84490547e-01], [-7.72037655e-02, -2.22850427e-01, 7.46268854e-02, -1.06892399e-01, -4.03373092e-02, 1.43766273e-02, 1.40153512e-01, -6.31121248e-02, -8.07491690e-02, 1.55763984e-01], [ 1.85379386e-02, -2.84236252e-01, 2.49140605e-01, -2.10829794e-01, -6.11658916e-02, 5.26331663e-02, 4.89872694e-03, -2.52614915e-03, 2.61847600e-02, 1.39026672e-01], [-2.97715850e-02, -1.57704860e-01, 9.88153145e-02, -2.44123742e-01, 6.82935119e-03, 9.55739543e-02, 1.77407395e-02, -9.62423012e-02, -3.06694172e-02, -8.97377729e-04], [-7.62552842e-02, -2.15193182e-01, 1.10395223e-01, -1.49222597e-01, 6.31211698e-03, 1.07793435e-02, 9.34207141e-02, -2.95278970e-02, -5.32062501e-02, 5.18209413e-02], [-1.06743507e-01, -2.06863925e-01, 1.17143326e-01, -2.44419351e-02, 3.90787423e-03, 3.25693935e-03, 8.52572620e-02, -7.01214448e-02, -6.87168166e-02, 8.35259482e-02], [-1.37559652e-01, -2.67782450e-01, 
9.01636854e-02, -2.43932471e-01, -2.19226599e-01, -3.02725025e-02, 1.46772370e-01, -6.44132793e-02, -1.46171749e-01, 2.30021209e-01], [-2.11456314e-01, -2.36388713e-01, 1.02335475e-01, -1.85400859e-01, -2.42546767e-01, -1.36911750e-01, 1.12711266e-01, -1.76483542e-01, 1.11777037e-02, 1.63012683e-01], [-1.13742456e-01, -2.88160145e-01, 3.06077778e-01, -8.28013793e-02, -6.80339709e-03, -7.89407492e-02, 9.80492905e-02, -1.51466310e-01, -1.06623411e-01, 1.75622880e-01], [-3.90224606e-02, -1.20289572e-01, 1.14604212e-01, -2.33187079e-01, -7.47915804e-02, -9.50895622e-03, 8.48786384e-02, -3.44136357e-03, -1.73418373e-02, 7.97818527e-02], [-1.76164024e-02, -1.63786843e-01, 1.69167370e-01, -3.05297762e-01, 9.24772024e-03, -1.02365762e-02, -7.10207000e-02, -2.40475833e-02, 7.44278207e-02, 8.85017216e-04], [-3.34250145e-02, -9.09714475e-02, 1.83904856e-01, -1.87584057e-01, -1.30171150e-01, 9.72316712e-02, 7.86449313e-02, -1.13042377e-01, -3.42860781e-02, 1.47009969e-01], [ 1.23950452e-01, -2.58233070e-01, 1.45953774e-01, -2.16587007e-01, -6.97525516e-02, 3.25547494e-02, 1.71580821e-01, -7.43409917e-02, -1.28298417e-01, 2.07221001e-01], [ 1.41256377e-02, -2.13005051e-01, 2.58257866e-01, -1.79573953e-01, -2.71227211e-04, 5.37357479e-03, 2.53414437e-02, 1.36982091e-02, -9.86021683e-02, 6.34129494e-02], [-9.53393206e-02, -2.51850098e-01, 7.02230334e-02, -2.41145603e-02, -4.18265685e-02, -5.73850162e-02, 1.40292212e-01, -8.98786932e-02, -3.68157625e-02, 1.81666166e-01], [-2.07753152e-01, -3.12528133e-01, 1.89828843e-01, -3.09255153e-01, -1.10048883e-01, 1.31216094e-01, 1.19376153e-01, -9.24888998e-02, -2.41679773e-02, 1.30824149e-01], [-1.73711367e-02, -1.57934368e-01, 2.63651330e-02, -1.49949104e-01, -4.08344306e-02, 2.76439488e-02, 1.11482270e-01, -5.71934357e-02, 1.20028351e-02, 6.82912618e-02], [-1.63935825e-01, -1.96109533e-01, 5.31068258e-02, -3.45816255e-01, -1.36397362e-01, -5.61913103e-03, 8.17488879e-02, -8.61469805e-02, -4.42591459e-02, 1.56713396e-01], 
[-9.78560969e-02, -2.81225532e-01, 2.51580268e-01, -2.79330462e-01, -5.54226488e-02, 2.85820842e-01, 1.37798816e-01, 5.53442948e-02, -9.66755450e-02, 7.41188601e-02], [ 2.18282156e-02, -2.45405048e-01, 2.76525080e-01, -2.43340984e-01, -9.32168737e-02, -1.03917256e-01, 1.76020056e-01, -1.96946055e-01, -1.05199814e-01, 1.43517733e-01], [ 1.91491842e-03, -2.14782000e-01, 2.14402735e-01, -1.35252386e-01, -2.68143285e-02, 2.60272920e-02, 1.14399910e-01, -7.09561855e-02, -4.79233414e-02, 9.08451229e-02], [ 4.69092131e-02, -1.37026012e-01, 1.17278881e-01, -1.10016525e-01, 4.56601195e-03, -6.45425990e-02, 1.03515517e-02, -1.12781025e-01, -9.09203887e-02, 2.90042795e-02], [-1.14894003e-01, -2.83115238e-01, 1.25633582e-01, -1.08892918e-01, -6.60854280e-02, 1.34425879e-01, 1.28590137e-01, -6.46811128e-02, -1.05066359e-01, 7.78829530e-02], [-1.34159416e-01, -1.16149344e-01, 1.10704780e-01, -1.82501853e-01, -1.11571878e-01, -1.78051442e-02, -7.05918521e-02, -8.96947682e-02, -2.43886709e-02, 1.28249973e-01], [-8.70557725e-02, -3.13068718e-01, 2.05983564e-01, -3.60752761e-01, -2.52963424e-01, 9.56458598e-02, 4.99115959e-02, -2.00084984e-01, -2.05507725e-02, -4.71021235e-03], [ 2.53530964e-02, -3.20023328e-01, 9.63085815e-02, -1.85679093e-01, -7.90637583e-02, 1.14473905e-02, 2.17339128e-01, -1.33131921e-01, -1.09921828e-01, 1.76644355e-01], [-8.13475475e-02, -2.46280625e-01, 1.63768873e-01, -2.03582734e-01, -1.11366063e-01, 5.85865155e-02, 1.05814189e-01, -1.29950434e-01, 4.80654463e-02, 9.21947211e-02], [ 2.88170986e-02, -2.13467956e-01, 2.08763123e-01, -2.61732817e-01, -5.98498695e-02, 1.69790924e-01, 8.89813602e-02, -2.64669470e-02, -9.32441205e-02, 1.07261270e-01], [-2.18984842e-01, -2.97292233e-01, 2.52088934e-01, -2.72813499e-01, -2.83467293e-01, -3.84202600e-03, 1.62944064e-01, -1.83288395e-01, 6.60075694e-02, 1.69622019e-01], [-8.47560316e-02, -2.25378931e-01, 9.83616710e-02, -2.56649733e-01, -9.33569446e-02, 1.18389249e-01, 1.26382694e-01, 5.40108047e-03, -8.15389231e-02, 
7.32447356e-02], [-2.15248689e-01, -2.89983273e-01, 2.81054348e-01, -2.34807819e-01, -8.91850218e-02, 8.13726410e-02, 2.78547443e-02, -1.62431747e-01, -8.10221583e-02, 1.28494352e-01], [-8.11575055e-02, -2.12961227e-01, 1.35927528e-01, -1.84795290e-01, 1.68305412e-02, 7.67179057e-02, 1.78498998e-02, 7.99339041e-02, -5.19917384e-02, 7.88615197e-02], [-1.39584109e-01, -1.53254420e-01, 1.40996322e-01, -1.42457187e-01, -1.61698401e-01, -5.59826791e-02, 5.54615706e-02, -9.48763862e-02, -3.23499776e-02, 7.35653192e-03], [-1.77557528e-01, -2.59482116e-01, 6.93101287e-02, -2.48186603e-01, -2.21639857e-01, 1.79100707e-02, 6.63381070e-02, -7.36963004e-02, -8.53785872e-03, 2.26362363e-01], [-4.66247573e-02, -1.87552884e-01, 7.50766397e-02, -1.83468118e-01, -6.94032758e-02, 7.31687471e-02, 1.05809800e-01, -4.82021868e-02, -1.43606484e-01, 7.33558759e-02], [-5.61062619e-02, -2.11781189e-01, 9.42066908e-02, -1.46039844e-01, -4.92455214e-02, 2.57240999e-02, 1.21175073e-01, -2.49509886e-03, -8.16714466e-02, 9.40574110e-02], [-8.77167284e-03, -1.11339316e-01, 2.08961472e-01, -2.24435702e-01, -5.41306883e-02, 8.92893150e-02, 1.43988505e-02, -9.36872140e-03, -3.50931883e-02, 2.20432207e-02], [-1.67946219e-02, -1.56142265e-01, 2.78167158e-01, -2.60964632e-01, -3.65047120e-02, 1.09166548e-01, 5.87949678e-02, -2.11432949e-03, -6.34969622e-02, 3.21010947e-02], [-6.98461309e-02, -2.51005530e-01, 1.15319267e-01, -4.03009146e-01, -1.53814182e-01, 1.31292313e-01, -8.88947397e-04, -1.04614519e-01, 3.27713639e-02, 8.19778293e-02], [-4.96144891e-02, -2.34044015e-01, 8.73164386e-02, -1.27475381e-01, -3.70700210e-02, -6.38843030e-02, 5.96667975e-02, 6.24071062e-03, -1.49838597e-01, 1.23565733e-01], [-1.61013320e-01, -2.79918849e-01, 8.42825323e-02, -2.26443440e-01, -4.07436192e-02, 9.87039953e-02, 9.74905342e-02, -1.89107805e-02, -3.60541046e-02, 1.38564959e-01], [-1.54935226e-01, -2.74246246e-01, 3.52338314e-01, -2.80018985e-01, -1.63894460e-01, 3.03545818e-02, -1.97691619e-02, -6.44856766e-02, 
-6.44719303e-02, 1.43361017e-01], [-2.89845139e-01, -1.95270896e-01, 1.58085182e-01, -3.41276884e-01, -1.88264221e-01, 1.11826345e-01, 7.22953156e-02, -9.40560848e-02, -1.72010846e-02, 6.65165558e-02], [ 2.80667376e-03, -2.22258165e-01, 3.22467759e-02, -2.36251205e-01, -7.11319596e-03, 8.50130171e-02, 1.23804763e-01, -1.12940632e-02, -6.83347136e-02, 5.93173727e-02], [-1.05941780e-01, -2.20807061e-01, 2.88082898e-01, -2.98589587e-01, -9.30554122e-02, 1.04942888e-01, 3.52538377e-02, -9.22118574e-02, 2.07180530e-03, 3.60559747e-02], [ 2.03453898e-02, -2.02152938e-01, 2.23098144e-01, -1.03950635e-01, -1.36403218e-02, -1.37776323e-03, 7.07350224e-02, -7.44753927e-02, -6.00884408e-02, 2.88838595e-02], [-3.67555171e-02, -3.20030272e-01, 9.78964120e-02, -1.48517996e-01, -9.19343829e-02, 1.52479857e-01, 1.42603204e-01, 3.52594629e-02, 2.06234530e-02, 6.88164756e-02]], dtype=float32)>, 1: <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[ 6.75567240e-02, -2.38914177e-01, 1.33755952e-01, -2.07327485e-01, -9.19160843e-02, 5.89535609e-02, 1.93401918e-01, -1.64012074e-01, -3.17274481e-02, 1.14088662e-01], [-2.23600149e-01, -2.49316663e-01, 1.22388750e-01, -3.41853201e-01, -1.52160853e-01, 8.29158872e-02, 3.99679951e-02, -1.16016984e-01, -1.99246258e-02, 1.21608429e-01], [-1.53935015e-01, -3.47085416e-01, 2.30311319e-01, -2.99614340e-01, -1.39243215e-01, 7.35733509e-02, 1.06897958e-01, -9.20633078e-02, -2.49818861e-02, 2.51860380e-01], [-1.09977022e-01, -1.20218769e-01, 1.51068568e-01, -1.49315774e-01, -4.17811833e-02, 3.39379460e-02, 1.07125811e-01, -8.54823291e-02, -1.41600668e-01, 7.99349621e-02], [-1.18387341e-01, -1.79526880e-01, 2.39438504e-01, -2.55786717e-01, -2.42254473e-02, 6.90134019e-02, -1.44913606e-02, -1.83255941e-01, 4.52577323e-03, -2.48518493e-02], [-6.42177686e-02, -2.41901368e-01, 1.50756523e-01, -2.32527256e-01, -1.02800429e-02, -3.01985294e-02, 1.02250203e-01, -4.66635041e-02, -7.00209141e-02, 8.70777220e-02], [-3.93515602e-02, -1.80711269e-01, 
3.10113356e-02, -1.63825929e-01, 1.16567761e-02, -1.84918195e-03, 1.57905579e-01, -1.40320407e-02, 6.27957284e-04, 9.89601538e-02], [-9.30578932e-02, -2.11971626e-01, 1.43260211e-01, -1.89392775e-01, -3.45423371e-02, 8.55274871e-02, 2.28122361e-02, -7.40249753e-02, -7.42083490e-02, -2.24770978e-03], [-1.15170747e-01, -1.46501660e-01, 1.15974933e-01, -2.14340284e-01, -1.62206702e-02, -5.78544959e-02, -2.78132670e-02, -3.13530304e-03, -5.26326448e-02, 1.29651099e-01], [-2.38952264e-02, -2.04561457e-01, 1.04315609e-01, -2.75441974e-01, -3.13567668e-02, 8.28484371e-02, 1.18773624e-01, 1.61623191e-02, -4.39512245e-02, 8.09860602e-02], [-9.01764333e-02, -6.08202852e-02, 9.35570225e-02, -1.81577951e-01, -3.92897651e-02, 1.72614586e-02, 1.68434158e-02, 1.97312720e-02, 4.90831211e-03, -8.68381001e-03], [-5.88432439e-02, -1.96920246e-01, 9.43311602e-02, -2.19087839e-01, -7.72496611e-02, 1.98269002e-02, -3.19615677e-02, -5.05579785e-02, 5.13686463e-02, 1.66076794e-02], [ 9.41839814e-03, -2.58791268e-01, 1.48562372e-01, -1.96932435e-01, -1.38596475e-01, 1.45630687e-01, 9.16977003e-02, -6.70069009e-02, -4.51352410e-02, 8.43038410e-03], [ 2.93919519e-02, -2.10931510e-01, 2.25182056e-01, -1.35816023e-01, -3.68131325e-02, -3.75051200e-02, -4.70872223e-03, -3.74109745e-02, -1.19374879e-01, 6.64644539e-02], [ 4.75103408e-02, -1.86894327e-01, 1.51290238e-01, -2.79287338e-01, -1.53363720e-01, 1.47845596e-02, 1.75720379e-01, -5.02634346e-02, -5.87904453e-03, 1.18566155e-01], [-3.42349894e-02, -1.89912662e-01, 1.13818869e-01, -1.54048502e-01, 7.97697529e-03, -3.13553810e-02, 1.12695418e-01, -1.28714517e-02, -1.44971073e-01, 1.54360116e-01], [-1.37819290e-01, -2.06382483e-01, 2.62454152e-01, -2.56959349e-01, -9.86221433e-03, 1.18425936e-01, 3.65710557e-02, -1.97922252e-02, -2.31456757e-03, 1.21300094e-01], [ 4.10704911e-02, -1.95815235e-01, 1.57558531e-01, -2.89344847e-01, -5.17649353e-02, 1.08768210e-01, -1.35319531e-02, -1.15863279e-01, -1.09515153e-02, 1.19357407e-02], 
[-1.43476054e-01, -2.10264921e-01, 2.32700467e-01, -2.91506439e-01, -2.21236050e-03, 1.20373160e-01, 9.87835005e-02, -2.51563042e-02, -8.44163150e-02, 4.65912372e-03], [ 9.55003127e-02, -1.94997519e-01, 1.45701885e-01, -2.97760904e-01, -4.06291261e-02, 1.23578221e-01, 7.91319758e-02, -9.34391096e-03, -1.14004582e-01, 4.87460196e-02], [-6.28919154e-03, -1.93234399e-01, 2.07127661e-01, -1.23730406e-01, -8.91654193e-02, 7.27545470e-02, 1.67791069e-01, -1.97549075e-01, -1.20808989e-01, 8.30601826e-02], [-1.49995640e-01, -2.76141942e-01, 7.16349334e-02, -2.07400069e-01, -2.80750170e-03, 2.61893049e-02, 2.00932533e-01, 2.86746919e-02, -6.80275112e-02, 1.04255706e-01], [-3.76789831e-03, -2.46453315e-01, 1.70502365e-01, -3.04748714e-01, -1.11648008e-01, 1.09304577e-01, 1.46715984e-01, -9.28808898e-02, -1.20145857e-01, 1.25662416e-01], [-7.73534104e-02, -2.67341197e-01, 1.33661300e-01, -2.26713523e-01, -2.06280217e-01, 2.44263783e-02, 1.44257456e-01, -1.41797826e-01, -2.30760947e-02, 1.59830034e-01], [-1.03204787e-01, -1.79054022e-01, 1.18528701e-01, -1.57337412e-01, -7.18231797e-02, -1.95176750e-02, 6.60207495e-02, 2.03975663e-03, -4.26177606e-02, 1.90442830e-01], [-1.79974467e-01, -1.97920576e-01, 1.04850456e-01, -1.63971022e-01, -9.94304940e-02, -2.10145488e-02, -1.59724206e-02, -7.93610737e-02, -7.53567666e-02, 6.35726154e-02], [-7.11826012e-02, -2.34555811e-01, 2.13445142e-01, -3.07684302e-01, -4.45934460e-02, 1.10307448e-01, 7.35006258e-02, 5.52741475e-02, -1.58808827e-01, 5.92289120e-02], [-2.13126108e-01, -1.64581329e-01, 4.67392690e-02, -1.46799356e-01, -1.03923678e-01, 2.69821379e-02, 9.77789611e-02, -1.22085780e-01, -7.33418316e-02, 1.27354473e-01], [-1.08663574e-01, -1.98505193e-01, 1.41354769e-01, -1.41805619e-01, -6.71410561e-02, 7.71894306e-02, 1.13328248e-01, -5.07075340e-02, -1.38026953e-01, 5.31041361e-02], [-1.32619187e-01, -1.72933936e-01, 1.28968596e-01, -1.82453096e-01, -4.75453287e-02, 9.17547569e-02, 7.28915930e-02, -6.51091859e-02, -1.33371800e-01, 
6.43559098e-02], [ 1.07341073e-03, -1.84513748e-01, 3.40713449e-02, -1.62618071e-01, 5.27607352e-02, 2.40905657e-02, 1.33244365e-01, -7.88654014e-03, -3.12274247e-02, 1.29569262e-01], [-4.94147651e-02, -2.04918981e-01, 5.29308319e-02, -2.47300029e-01, -1.31541356e-01, 6.23399988e-02, 6.43031001e-02, -5.82937337e-02, -3.84260975e-02, 9.14722979e-02], [-2.31764652e-02, -2.01847285e-01, 1.36055797e-01, -1.50709137e-01, -6.48210198e-02, 1.17176846e-02, 1.18309550e-01, -1.26468271e-01, -1.10086706e-02, 6.10472038e-02], [-1.31686330e-01, -1.88606307e-01, 5.88058047e-02, -2.45802298e-01, -8.22002739e-02, 7.32430220e-02, 2.33802926e-02, 5.33343442e-02, -4.92597073e-02, -4.78738733e-02], [-2.77145095e-02, -7.30694160e-02, 6.33261651e-02, -1.62415922e-01, -8.39199871e-03, 7.93164074e-02, 1.12110309e-01, -3.17870900e-02, -1.02302730e-01, 8.16293359e-02], [-2.33535245e-02, -2.35255420e-01, 1.09575137e-01, -2.99337596e-01, -9.92065594e-02, 9.49253887e-02, 6.46477658e-03, -7.73567110e-02, -3.07335034e-02, 8.38225335e-02], [ 4.21159938e-02, -1.82390302e-01, 1.72212362e-01, -1.88702613e-01, 3.05568986e-02, -8.93439651e-02, 4.43341769e-02, -1.82540379e-02, -7.46393129e-02, 9.38700587e-02], [-3.24128926e-01, -3.46765816e-01, 1.40220657e-01, -3.02282214e-01, -1.64297551e-01, 3.75131369e-02, 1.05158560e-01, -1.28027707e-01, -3.64448056e-02, 1.48398817e-01], [-9.46170017e-02, -1.27747789e-01, 1.01317704e-01, -1.31414354e-01, 2.40935534e-02, 4.38534915e-02, 5.19645400e-02, -6.66568354e-02, -1.02555230e-01, 1.30323857e-01], [-3.93612012e-02, -1.81893080e-01, 2.32378110e-01, -2.48832300e-01, -2.94160806e-02, 1.27097592e-01, 1.29024684e-03, -5.54355457e-02, -3.50462496e-02, 1.62917644e-01], [-5.88566437e-03, -2.51869023e-01, 1.37316018e-01, -2.04826683e-01, -4.83629592e-02, -4.02906351e-02, 1.58909738e-01, -9.28904042e-02, -1.49646878e-01, 1.23233311e-01], [ 1.50584102e-01, -1.82133600e-01, 1.45644158e-01, -2.02355444e-01, -3.01427692e-02, 1.39219448e-01, 6.36230707e-02, -2.22080909e-02, 
-6.22782111e-02, 4.97137979e-02], [-8.17728937e-02, -1.90628260e-01, 1.78402692e-01, -2.39471644e-01, -1.54569790e-01, 1.57149822e-01, 3.74688730e-02, -7.91701078e-02, -3.61192375e-02, 7.98084810e-02], [-8.21674317e-02, -1.73789293e-01, 3.86341810e-02, -2.17232525e-01, -9.88442898e-02, -1.70633346e-02, 1.09598055e-01, -3.63032296e-02, -6.62199855e-02, 4.37744111e-02], [-8.46960321e-02, -1.05847187e-01, 6.20741844e-02, -1.02550842e-01, -4.40515429e-02, 3.96654233e-02, 4.43543531e-02, -5.37615344e-02, -7.08982646e-02, 2.58079097e-02], [-1.10534452e-01, -2.73831189e-01, 2.41692096e-01, -2.95142770e-01, -1.25000581e-01, 1.05496168e-01, 1.00305468e-01, -1.80569999e-02, -1.85078084e-02, 8.71269554e-02], [ 4.35274318e-02, -2.66818404e-01, 2.77568698e-01, -1.63077131e-01, 8.64195600e-02, 3.57484072e-02, 2.33793333e-02, 9.71107408e-02, -3.17685902e-02, 1.08533472e-01], [-1.03916764e-01, -2.57931769e-01, 1.31162852e-01, -1.68040425e-01, -1.77489594e-01, 2.28280649e-02, 1.09173112e-01, -1.58080742e-01, -6.75074980e-02, 1.83772370e-01], [ 7.00471699e-02, -1.44712493e-01, 5.51679246e-02, -1.47630006e-01, -6.52111471e-02, -4.57328409e-02, 2.70138960e-02, -1.18774213e-01, -8.94630551e-02, 8.67242515e-02], [-1.05358377e-01, -1.44616842e-01, 1.54398665e-01, -1.96894646e-01, -1.71607547e-02, 3.31663117e-02, 2.50310898e-02, 1.51388571e-02, -3.87620144e-02, 2.03875154e-02], [-1.87340379e-01, -2.15252638e-01, 1.39404327e-01, -2.53099680e-01, -1.66907489e-01, 3.76047194e-02, 1.07667297e-01, -1.07199512e-01, -7.23691881e-02, 1.68525219e-01], [ 2.01733168e-02, -1.50373489e-01, 1.76286429e-01, -7.97204971e-02, -1.12195268e-01, -7.87620768e-02, 1.21824570e-01, -5.35783581e-02, -1.59563199e-01, 9.11581069e-02], [-3.14304531e-02, -2.35860884e-01, 6.42290190e-02, -2.01830775e-01, -8.17677304e-02, 8.04069340e-02, 1.83426544e-01, -4.26740199e-02, -4.14486527e-02, 7.44916052e-02], [-1.30108260e-02, -2.16072693e-01, 1.88388437e-01, -1.95033506e-01, -9.41881686e-02, 3.52058709e-02, 1.16087236e-01, 
-4.11959216e-02, -1.37238830e-01, 5.48713952e-02], [-1.86875999e-01, -1.74092755e-01, 1.44259140e-01, -2.21626043e-01, -6.76412433e-02, 2.27241032e-02, 2.78104097e-04, -5.67700416e-02, -3.89099047e-02, 5.84937111e-02], [-2.78417021e-02, -2.42510140e-01, 2.46906400e-01, -2.44772315e-01, -5.41119277e-03, 7.76698217e-02, 1.32741779e-01, -5.95767014e-02, -1.00354224e-01, 1.16295107e-01], [ 2.81312205e-02, -2.12746531e-01, 1.18400827e-01, -2.01335326e-01, -1.08814105e-01, -8.70871171e-03, 1.15727030e-01, -5.08784950e-02, -5.90119511e-02, 9.64380130e-02], [-1.09740384e-01, -2.43449017e-01, 2.03824818e-01, -3.86486232e-01, -1.03655666e-01, -2.66943872e-03, -1.29154325e-03, -9.31151584e-02, 4.89664078e-02, 1.13949358e-01], [-6.35606200e-02, -2.54489332e-01, 2.56252028e-02, -2.81080276e-01, -7.74140954e-02, 2.45606564e-02, 5.71488701e-02, 3.81116010e-03, -1.45976931e-01, 3.64610404e-02], [ 1.00376189e-01, -2.20337451e-01, 8.64842683e-02, -1.78378254e-01, -6.71998113e-02, 4.18653563e-02, 2.58925885e-01, -1.50193900e-01, -6.32833838e-02, 6.70849681e-02], [-8.17512572e-02, -2.87286878e-01, 6.71994686e-02, -1.14006430e-01, 5.49807847e-02, 1.19937584e-01, 1.02814890e-01, 3.91088054e-02, -5.14580943e-02, 1.02065146e-01], [-1.33380964e-02, -2.44644374e-01, 2.17214927e-01, -1.48414731e-01, -7.70019591e-02, -1.34491585e-02, 7.06726164e-02, -4.43194248e-02, -3.47995572e-02, 9.49392319e-02], [ 9.09983367e-02, -2.26088285e-01, 1.56397849e-01, -2.00250998e-01, 6.76592812e-03, 6.34500533e-02, 1.10073805e-01, -5.11747636e-02, -7.38727897e-02, 1.07999891e-01], [-7.74010867e-02, -2.58765668e-01, 1.42947018e-01, -1.16328359e-01, -4.54804264e-02, -6.79061860e-02, 1.80663586e-01, -5.15852086e-02, -1.43913418e-01, 1.31981447e-01]], dtype=float32)>, 2: <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[-2.92910397e-01, -2.47621059e-01, 6.90507963e-02, -2.19161585e-01, -2.21360564e-01, -9.72218812e-03, 8.65268558e-02, -1.84303567e-01, 7.39633292e-02, 9.26390439e-02], [-2.76300550e-01, 
-2.86523432e-01, 1.03790835e-01, -2.50321865e-01, -3.67840767e-01, 5.52606620e-02, 1.56405732e-01, -2.40908995e-01, -2.32718661e-02, 1.47031933e-01], [-8.73501897e-02, -2.46659130e-01, 2.13791877e-01, -3.35097313e-01, -8.17968696e-03, 1.51596636e-01, 1.48926795e-01, 2.56685019e-02, -1.42294928e-01, 1.58504218e-01], [-7.83732757e-02, -2.01639578e-01, 1.12231679e-01, -1.91021413e-01, -3.09330877e-02, 1.49019212e-02, 2.91242450e-02, -1.07948389e-02, -1.09200925e-02, 1.09643362e-01], [-1.44622475e-02, -1.94366693e-01, 1.39955163e-01, -5.29873073e-02, 4.69535217e-03, -5.78343868e-02, 1.24284156e-01, -7.85315782e-03, -6.92758560e-02, 1.56148106e-01], [-2.15620652e-01, -2.79937387e-01, 2.46239170e-01, -3.61058503e-01, -8.46953541e-02, 8.27162415e-02, 5.81266582e-02, -6.95971325e-02, 4.19765711e-03, 5.22745773e-02], [-6.77132830e-02, -2.19626561e-01, 6.83878735e-02, -7.11585283e-02, -1.12009250e-01, -1.61485821e-02, 1.37567550e-01, -9.79662314e-02, -1.02077886e-01, 1.62132189e-01], [-4.25315201e-02, -1.52483866e-01, 1.52330756e-01, -2.21901804e-01, -9.88871157e-02, 1.32006407e-02, 1.17788263e-01, -1.45564049e-01, -1.16822071e-01, 9.02762711e-02], [-1.49164647e-01, -1.15676403e-01, 1.00291103e-01, -2.29233235e-01, -2.76894439e-02, 4.52283993e-02, 2.61685997e-03, 4.93768603e-03, -2.07472444e-02, -1.40195005e-02], [-1.25718370e-01, -1.63824499e-01, 1.27739802e-01, -1.51549771e-01, -2.92051807e-02, 7.72437304e-02, 1.64590597e-01, -9.08653513e-02, -1.62190646e-01, 1.16016909e-01], [-1.48351267e-01, -2.31614679e-01, 1.77401215e-01, -2.59321064e-01, -2.35821828e-01, -2.85175219e-02, 2.08227664e-01, -1.90595418e-01, 2.49112062e-02, 1.31946355e-01], [-8.27215165e-02, -2.29101837e-01, 1.11695848e-01, -1.74895227e-01, 4.56438474e-02, 5.43713868e-02, 9.87353697e-02, 2.62146965e-02, -1.30267262e-01, 1.14205681e-01], [-1.08721137e-01, -2.72843093e-01, 9.31099802e-02, -1.23697445e-01, -9.63272154e-03, 1.27321407e-02, 6.98904097e-02, -9.42568481e-03, -8.82773101e-02, 1.25001103e-01], 
[-1.00479007e-01, -2.81627387e-01, 1.77257553e-01, -1.38115942e-01, -2.13532656e-01, 1.28176063e-03, 2.37812802e-01, -1.71448037e-01, 6.75702170e-02, 1.49718404e-01], [-1.39466435e-01, -3.68360281e-01, 3.77481937e-01, -3.58653575e-01, -1.56710133e-01, 3.23769450e-02, 1.39575779e-01, -1.87596664e-01, -4.46216241e-02, 1.02873258e-01], [-2.04328567e-01, -3.08403432e-01, 2.15267628e-01, -4.57034171e-01, -3.16666394e-01, 1.56880215e-01, 2.55806923e-01, -1.68201983e-01, -1.20759964e-01, 1.54574573e-01], [-1.71269506e-01, -2.17248201e-01, 3.59796584e-02, -2.40118355e-01, -1.63183182e-01, -8.52251053e-03, 1.41886458e-01, -1.50629103e-01, -3.87271047e-02, 1.90065891e-01], [-4.66681868e-02, -1.78006977e-01, 2.20160872e-01, -2.23264426e-01, -7.44061023e-02, -1.20238170e-01, 4.84996289e-02, 2.61790790e-02, 3.07720937e-02, 1.26800299e-01], [-7.09847361e-02, -1.31401077e-01, 1.01098686e-01, -1.61830604e-01, -1.67847320e-01, 1.49966776e-03, 5.40902689e-02, -1.19909227e-01, -3.46334092e-02, 4.17901911e-02], [-4.81188744e-02, -2.31777966e-01, 1.33515656e-01, -1.77572027e-01, -9.27763805e-02, 1.19939223e-01, 1.91938803e-01, 1.30975451e-02, -5.63915595e-02, 1.85789168e-01], [-2.18921661e-01, -2.48723641e-01, 2.59925425e-01, -2.24437684e-01, -1.10277042e-01, -5.78821786e-02, 1.41703337e-03, -1.20933898e-01, -1.25650913e-01, 7.68510327e-02], [-1.78173445e-02, -1.89112589e-01, 1.07707471e-01, -1.89524472e-01, -1.31092966e-03, -3.67355272e-02, 4.30463105e-02, -5.70860542e-02, -4.75289822e-02, 7.66582564e-02], [-1.72761708e-01, -2.45659009e-01, 1.66983366e-01, -2.54466474e-01, -1.17533267e-01, -6.65320829e-02, 4.86670732e-02, -1.19520187e-01, 2.17647851e-02, 8.45992565e-02], [ 3.37264203e-02, -2.43884057e-01, 1.37869015e-01, -1.90368980e-01, -6.26042187e-02, 1.09830603e-01, 1.30694136e-01, -4.54947837e-02, -1.65158048e-01, 7.82282799e-02], [ 6.20651506e-02, -2.80856282e-01, 1.39889494e-01, -2.16678217e-01, -2.06417799e-01, 1.38500735e-01, 1.23121470e-01, -1.58497691e-01, -5.24034649e-02, 
9.95990485e-02], [-2.10504338e-01, -2.77220428e-01, 1.77865088e-01, -1.43282861e-01, -3.75899374e-02, 2.29452960e-02, 1.10701233e-01, -1.64192513e-01, -8.81378055e-02, 1.28708452e-01], [ 1.08307555e-01, -2.17978150e-01, 1.52921289e-01, -9.37823802e-02, 1.64098274e-02, -4.94194441e-02, 1.54389471e-01, -8.04066584e-02, -7.47242048e-02, 1.96253464e-01], [ 1.82701647e-02, -2.41666913e-01, 3.30916643e-01, -2.89392531e-01, 2.86299177e-02, 1.80207133e-01, 1.20492458e-01, -3.68874334e-03, -1.09514087e-01, 2.38499939e-02], [-2.18268231e-01, -1.96707711e-01, 6.73187375e-02, -3.34713846e-01, -1.60789937e-01, -3.07322592e-02, -7.08551779e-02, -1.07558824e-01, 7.43071511e-02, 9.46420357e-02], [-8.31968337e-02, -1.67009011e-01, 2.08477899e-01, -3.24087173e-01, -4.97262701e-02, 9.22416002e-02, 3.01967561e-03, -1.68635733e-02, -2.21733004e-02, 3.29319015e-02], [-7.85755068e-02, -2.02665269e-01, 1.14596143e-01, -2.07859248e-01, -3.78912464e-02, 3.91333997e-02, -8.18152726e-03, -3.25387642e-02, 5.62441349e-03, -7.13631511e-04], [-1.08744502e-01, -1.73820540e-01, 6.82278052e-02, -1.93427488e-01, 5.87783158e-02, 1.29057392e-02, 1.13731325e-01, -8.31326470e-02, -9.97855291e-02, 1.88317955e-01], [-8.18054378e-02, -1.91538498e-01, 1.39281318e-01, -2.40328074e-01, -1.31923303e-01, -2.29768604e-02, 1.07800342e-01, -1.91812992e-01, 2.21445858e-02, 9.91273448e-02], [-9.42839384e-02, -2.12466210e-01, 2.76104271e-01, -3.45111758e-01, -1.24919176e-01, 4.16840166e-02, 2.50887424e-02, -6.03887811e-02, -3.20043415e-02, 1.54104218e-01], [-1.50097802e-01, -2.29517475e-01, 1.50762454e-01, -2.29443863e-01, -8.23347345e-02, 7.48579353e-02, 1.36823013e-01, 4.61617485e-02, -4.60946709e-02, 1.48208708e-01], [-1.01659216e-01, -2.45093614e-01, 2.64105350e-01, -1.40816480e-01, -8.19988772e-02, 1.06125474e-02, -2.11667195e-02, -7.13749677e-02, -5.12884557e-03, 1.55431107e-02], [ 5.66820987e-02, -2.85035551e-01, 1.57222211e-01, -2.06601799e-01, -6.22553527e-02, 7.65531808e-02, 1.53478608e-01, -1.06203757e-01, 
-9.18237269e-02, 1.17344812e-01], [-7.37495795e-02, -1.94213808e-01, 1.85885191e-01, -1.52571782e-01, -6.66468963e-03, 1.06399871e-01, 7.87580088e-02, -5.24353087e-02, -1.31784335e-01, 5.49257547e-02], [ 2.28347629e-02, -2.54500806e-01, 2.07621306e-01, -1.90982759e-01, -9.54620987e-02, 6.86145127e-02, 1.20372787e-01, -1.03277802e-01, -1.61320299e-01, 4.03868109e-02], [-1.21820644e-01, -2.66319275e-01, 1.23099439e-01, -3.31621736e-01, -1.20969050e-01, 1.44172966e-01, 1.29650146e-01, -1.19315289e-01, 5.49161732e-02, 1.10396475e-01], [-5.71482629e-02, -2.34085366e-01, 6.66581318e-02, -1.63312227e-01, 1.33750141e-02, 1.14454858e-01, 2.04993337e-01, -3.72511819e-02, -4.19050679e-02, 5.16368002e-02], [-1.75331831e-01, -2.00746834e-01, 1.66912228e-01, -1.73353434e-01, 8.99408013e-04, 3.89906317e-02, 7.84541667e-03, -7.18510151e-03, -4.00707871e-02, 6.10330254e-02], [-1.34657100e-01, -3.50459218e-01, 1.85896263e-01, -3.04327965e-01, -4.08849627e-01, 5.86151481e-02, 1.27343044e-01, -2.63043374e-01, 4.35868055e-02, 9.73422453e-02], [-2.33950466e-02, -1.45173594e-01, 1.93431556e-01, -5.95074520e-03, 5.47398552e-02, -1.01979636e-02, 3.94684449e-02, 4.44147736e-05, -6.52863458e-02, 1.25901237e-01], [-2.33400270e-01, -2.77099907e-01, 1.70012623e-01, -3.25400025e-01, -3.17534387e-01, 7.07024634e-02, 1.21218272e-01, -2.24716857e-01, 6.35834858e-02, 7.80200660e-02], [-1.75186858e-01, -3.02092314e-01, 1.89863861e-01, -1.57291308e-01, -2.84965783e-02, 1.14309087e-01, 3.01322564e-02, -7.36754090e-02, -3.88177335e-02, 2.13922635e-02], [-6.73210472e-02, -9.82739702e-02, 1.56227544e-01, -7.04675466e-02, 2.06332654e-04, 8.09203386e-02, 1.86109096e-02, -5.64795882e-02, -4.57075164e-02, 7.62852132e-02], [-2.20414430e-01, -1.68171704e-01, 2.03580216e-01, -1.94860473e-01, -1.40957206e-01, 6.51835352e-02, -7.43473694e-02, -1.47300467e-01, -5.23896329e-03, -1.00979730e-01], [-1.69799864e-01, -1.60041898e-01, 1.47675991e-01, -1.66943923e-01, -8.50687325e-02, -3.23204324e-02, 2.50914507e-02, 
-1.07992783e-01, -2.50843614e-02, 9.48710740e-03], [ 2.77388208e-02, -1.21045031e-01, 6.39347658e-02, -1.24417767e-01, 4.94709574e-02, 8.18249211e-03, 8.22345167e-02, -3.93214822e-02, -1.23746172e-01, 1.42454475e-01], [-1.21482573e-02, -2.33502582e-01, 1.86448812e-01, -2.69238621e-01, -2.47626334e-01, 3.08376700e-02, 8.53330120e-02, -1.29280567e-01, -9.93426740e-02, 1.24004230e-01], [ 2.15106905e-02, -2.29617596e-01, 1.83588862e-01, -1.80853486e-01, -9.27157179e-02, -3.12608406e-02, 7.27558658e-02, -1.31818771e-01, -1.02937244e-01, 4.77900580e-02], [-1.34207189e-01, -1.58990055e-01, 1.32712230e-01, -2.07536668e-01, -9.75533277e-02, -3.75058427e-02, 5.24073131e-02, -1.12000749e-01, -5.42353205e-02, -6.10221922e-03], [-1.66911319e-01, -3.29504818e-01, 2.26329386e-01, -3.18372309e-01, -2.58815914e-01, 7.90991262e-02, 1.71128705e-01, -2.09883437e-01, -2.27044821e-02, 1.64031625e-01], [-3.81070375e-02, -1.52042240e-01, 5.88850938e-02, -6.83134273e-02, 3.95197235e-02, -2.34651025e-02, 1.15644529e-01, -4.76871654e-02, -1.01254530e-01, 1.56307653e-01], [-1.33788943e-01, -1.94588646e-01, 8.13895538e-02, -1.60082981e-01, -1.48450211e-01, -5.13584614e-02, 3.63013297e-02, -1.29578978e-01, 8.91127512e-02, 8.92005116e-03], [-8.50321800e-02, -2.60567844e-01, 1.05235487e-01, -2.05197290e-01, 8.70009735e-02, 1.62730500e-01, 1.66213632e-01, 1.05643667e-01, -1.50912479e-01, 7.83244222e-02], [-2.36878172e-02, -2.53476143e-01, 1.39658526e-01, -1.86545521e-01, -1.49292707e-01, 4.54190001e-02, 1.21261932e-01, -1.10977873e-01, -5.55660576e-02, 1.63696229e-01], [-5.03422916e-02, -1.65697366e-01, 2.60710623e-02, -1.71452627e-01, -5.65209687e-02, 8.02249461e-02, 1.27547771e-01, -3.21173817e-02, -1.62991345e-01, 1.22045174e-01], [-8.85733962e-03, -2.33069301e-01, 1.60033137e-01, -2.29237735e-01, -4.66225669e-03, 5.39076328e-02, 7.32713640e-02, 1.70185678e-02, -1.80564761e-01, 8.30461606e-02], [-4.99423817e-02, -2.57470608e-01, 8.87782201e-02, -7.79092163e-02, -5.64297512e-02, -4.15608212e-02, 
6.76474050e-02, -6.45275861e-02, -5.89100383e-02, 1.33590668e-01], [-3.49371880e-03, -3.32547307e-01, 7.61690363e-02, -1.68725461e-01, -1.34751067e-01, 2.00760923e-02, 2.61783063e-01, -1.46969214e-01, -7.66006187e-02, 2.36703336e-01], [-6.80768266e-02, -1.99373543e-01, 1.93201065e-01, -1.17352977e-01, 7.12655187e-02, 1.16392002e-02, 9.39144790e-02, -1.38962641e-03, -1.02351576e-01, 1.08713463e-01], [-2.81550437e-02, -1.95229873e-01, -3.60348187e-02, -2.01792449e-01, -6.71518669e-02, 4.98618260e-02, 1.06975645e-01, 1.61965601e-02, -1.11789890e-02, 6.69299513e-02]], dtype=float32)>, 3: <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[-1.23437926e-01, -2.11030200e-01, 7.07738250e-02, -7.69976526e-02, -6.43500537e-02, -6.28717169e-02, 2.99714357e-02, -3.28062996e-02, 5.16504608e-02, 9.42404345e-02], [-7.26455823e-04, -2.13334084e-01, 7.30197653e-02, -1.66986436e-01, -4.56670895e-02, 2.76470706e-02, 1.53899685e-01, -2.37623192e-02, -3.16666961e-02, 4.42121029e-02], [-2.13354826e-03, -1.75666809e-01, 8.51001665e-02, -1.42388791e-01, -5.21301404e-02, 7.05131218e-02, 1.22245640e-01, -4.14452441e-02, -1.23094648e-01, 7.01946989e-02], [-1.18392020e-01, -2.86092699e-01, 1.70314386e-01, -1.31102085e-01, -7.33606666e-02, 6.05323017e-02, 1.89448580e-01, -1.48474220e-02, -7.57791176e-02, 1.51703298e-01], [-6.50813356e-02, -2.07763746e-01, 1.45379066e-01, -2.61255175e-01, 9.87304375e-03, 1.09957360e-01, 5.93481660e-02, 8.72598365e-02, -1.12408638e-01, 6.00734800e-02], [-2.53683478e-02, -1.36544153e-01, 5.81072867e-02, -4.69100401e-02, 9.79136601e-02, 1.66488513e-02, 1.53901130e-02, -1.57943927e-03, -5.16440421e-02, 1.18039034e-01], [-7.32891113e-02, -1.10125624e-01, 8.57788101e-02, -1.77573532e-01, -2.87510268e-02, 2.23082006e-02, 2.15010419e-02, 1.23945773e-02, -3.64443287e-04, -1.41462758e-02], [-4.70349938e-02, -2.90533274e-01, 7.98520297e-02, -1.36518568e-01, -4.68233973e-02, 8.85699019e-02, 1.43734679e-01, -6.41912892e-02, -5.08216992e-02, 1.01841301e-01], 
[-7.54870772e-02, -3.13557804e-01, 2.67809272e-01, -2.40639359e-01, -1.99801490e-01, -9.13138837e-02, 1.73096776e-01, -1.36064738e-01, -3.37650627e-02, 1.98289171e-01], [-2.22930610e-01, -2.25699902e-01, 1.07240647e-01, -2.41287231e-01, -1.44734263e-01, 1.54250145e-01, 8.78515765e-02, -6.28602803e-02, 6.03416897e-02, 8.88438970e-02], [-1.74472794e-01, -1.96533009e-01, 3.55295300e-01, -2.14963138e-01, -1.08721137e-01, 2.95408070e-02, -1.09769732e-01, -1.76001355e-01, -6.58346936e-02, -8.53226408e-02], [ 1.57252755e-02, -1.63383633e-01, 1.76904202e-01, -1.07845865e-01, -2.97124833e-02, 1.56366885e-01, 2.95609683e-02, -6.65073544e-02, -3.96402329e-02, 7.56366700e-02], [-6.95464462e-02, -2.83027083e-01, 1.93344474e-01, -2.86879152e-01, -6.27231151e-02, -7.34102204e-02, 9.83634740e-02, -3.84070240e-02, -5.49911931e-02, 1.33385211e-01], [-1.49667367e-01, -2.56719440e-01, 2.14695662e-01, -3.08474004e-01, 2.87505314e-02, 4.68948632e-02, 1.04293235e-01, 3.12864967e-02, -1.40534639e-01, 1.16315052e-01], [-6.10959381e-02, -2.14545816e-01, 1.75310627e-01, -1.88737273e-01, -7.98486769e-02, -5.23116775e-02, -1.46004818e-02, -1.38967201e-01, -1.01485178e-01, 3.90701070e-02], [ 1.30137103e-02, -1.88789338e-01, 2.04270892e-02, -1.78047240e-01, 3.90435010e-02, -1.90784000e-02, 1.25754237e-01, 1.14320777e-02, -1.10013299e-01, 1.36420995e-01], [ 1.08336322e-02, -1.95143893e-01, 1.28844112e-01, -1.71142489e-01, -6.16367608e-02, 7.18701780e-02, 1.23423249e-01, 3.29518877e-03, -9.05838907e-02, 1.35054633e-01], [-1.80107653e-01, -2.09284469e-01, 9.61361304e-02, -1.92408979e-01, -5.47624007e-02, 8.05967860e-03, 7.47046322e-02, -5.30524775e-02, -1.24244705e-01, 1.27441645e-01], [-1.06447399e-01, -3.13466191e-01, 1.39964104e-01, -2.80033171e-01, -2.85412878e-01, -1.22623593e-02, 2.39364445e-01, -2.31475756e-01, -7.53569677e-02, 2.06900090e-01], [-2.76292115e-01, -3.10410082e-01, 2.84758747e-01, -2.45441675e-01, -1.41810790e-01, 4.31530774e-02, 2.99971104e-02, -1.24273337e-01, 
-3.62917148e-02, -4.92124781e-02], [-5.76281697e-02, -2.09266499e-01, 1.23424739e-01, -1.79520965e-01, -1.22136667e-01, -4.00786847e-03, 1.91921741e-02, -1.44317269e-01, 2.59571970e-02, 6.23718686e-02], [-3.93996537e-02, -1.67124778e-01, 2.38287672e-01, -1.69658273e-01, 4.96143475e-02, -1.43091679e-02, -2.71584019e-02, 4.25787643e-03, -2.49763280e-02, 6.06326014e-02], [ 1.70732245e-01, -2.62508571e-01, 2.17388391e-01, -2.07690760e-01, -1.08002909e-02, 6.53989986e-03, 1.25941932e-01, -9.77744311e-02, -2.31329128e-02, 1.13914981e-01], [ 1.60433128e-02, -2.09513843e-01, 6.75286651e-02, -1.37861341e-01, -2.99778134e-02, 4.19918112e-02, 1.72477245e-01, -9.76834632e-03, -4.02618200e-02, 7.57538527e-02], [ 1.90832317e-02, -1.72241166e-01, 3.16764116e-01, -3.23181212e-01, -2.82086506e-02, 1.27550080e-01, 4.65994626e-02, -3.42884026e-02, -8.70371461e-02, 8.29574317e-02], [-9.78137851e-02, -1.72557890e-01, 1.26011536e-01, -1.51672348e-01, -1.68658912e-01, 2.47031674e-02, 5.05703539e-02, -1.26446560e-01, -4.29677814e-02, 5.88577203e-02], [-1.74206756e-02, -2.31256858e-01, 1.43946901e-01, -2.33153120e-01, 6.21279329e-03, 1.50586590e-01, 1.40525028e-01, -1.69591848e-02, -1.13764554e-01, 6.77267537e-02], [-7.25764483e-02, -2.08863869e-01, 9.38807651e-02, -2.29979351e-01, -1.59057170e-01, -1.88987292e-02, 1.32216886e-01, -1.95539623e-01, 1.77561790e-02, 1.81802884e-01], [-1.35613859e-01, -2.62519002e-01, 1.21524513e-01, -1.50561064e-01, -1.61474675e-01, -1.13880709e-01, 1.49025053e-01, -1.55651301e-01, -6.48130104e-03, 1.20292723e-01], [-1.21942669e-01, -1.56770289e-01, 2.57681996e-01, -1.55968562e-01, -2.19095275e-01, 1.10958248e-01, 6.36297166e-02, -1.66549608e-01, -1.30507171e-01, 1.04893506e-01], [ 1.37873404e-02, -2.13068843e-01, 1.79353178e-01, -1.20212674e-01, 3.58381718e-02, -4.50651199e-02, 3.30163166e-02, -1.15171790e-01, -1.33268774e-01, 2.88895294e-02], [-1.28334463e-01, -2.76011050e-01, 2.17485175e-01, -2.99596727e-01, -1.37224853e-01, 1.62479579e-01, 2.09923983e-01, 
-8.74459073e-02, -1.08367696e-01, 9.83993635e-02], [-1.84781849e-01, -1.88956067e-01, 2.28338674e-01, -1.78052887e-01, -5.35424426e-02, 1.21365637e-02, -1.07846186e-02, -4.41186838e-02, -4.95132506e-02, 1.01655386e-02], [-2.42183685e-01, -1.97035596e-01, 2.55758226e-01, -2.43304133e-01, -1.50894076e-01, 1.05746508e-01, -3.57776880e-05, -2.20813155e-01, -6.98048174e-02, -2.82068141e-02], [-9.33281481e-02, -2.48022258e-01, 2.67458707e-01, -2.15162337e-01, 1.87640339e-02, 9.82057229e-02, -4.12685424e-03, 7.90382698e-02, 9.91706178e-02, 4.65318561e-02], [ 2.11611893e-02, -2.12271124e-01, 7.74582550e-02, -2.28891864e-01, -2.44507156e-02, 5.35808504e-02, 1.15697250e-01, -6.18771240e-02, -1.17006741e-01, 9.31429863e-02], [ 3.55179720e-02, -1.57204956e-01, 2.04363227e-01, -2.01112002e-01, -9.63690728e-02, 7.52578378e-02, 4.60278839e-02, -2.73984708e-02, -4.58910689e-02, 8.82677138e-02], [-7.28457123e-02, -2.30979905e-01, 2.58502275e-01, -2.02101186e-01, -1.26604959e-02, 1.85082227e-01, 9.88619924e-02, 5.87529838e-02, -4.16241214e-02, 4.52574901e-02], [-2.18311772e-02, -3.04886639e-01, 1.56199291e-01, -2.39684492e-01, -1.43737897e-01, 9.85218510e-02, 9.23808292e-02, -1.25980362e-01, -1.01783536e-01, 1.46639124e-02], [-1.69992715e-01, -3.35010260e-01, 2.39863455e-01, -4.33992386e-01, -1.44472539e-01, -4.53791097e-02, 7.22908154e-02, -1.00417703e-01, -7.51682967e-02, 1.34948671e-01], [-1.21986337e-01, -1.73863351e-01, 7.04779774e-02, -1.47365391e-01, 9.05528665e-04, 7.39799812e-02, 1.49948150e-01, -6.89401776e-02, -1.33584261e-01, 7.15210661e-02], [ 1.22973248e-02, -1.25874892e-01, 2.08711267e-01, -1.50454119e-01, 2.29926370e-02, -1.46374479e-03, 3.03103589e-02, -1.46747604e-02, -1.69257328e-01, 5.16710505e-02], [-5.05342633e-02, -1.68722659e-01, 8.36300477e-02, -1.74878657e-01, -6.31242394e-02, 6.49843365e-02, 1.13599911e-01, -2.68256534e-02, -1.36071458e-01, 9.77638662e-02], [-1.87490731e-01, -4.27204907e-01, 2.36114219e-01, -4.12069440e-01, -3.24858665e-01, 5.41152470e-02, 
1.31420240e-01, -2.53119111e-01, -6.34656847e-02, 2.21875519e-01], [-1.95926160e-01, -2.64419079e-01, 1.29408777e-01, -2.43842915e-01, -3.02285075e-01, -7.66035989e-02, 3.37431729e-02, -7.02858940e-02, 7.21262842e-02, 7.15475082e-02], [ 7.96887279e-03, -2.41170034e-01, 2.30757207e-01, -2.48726189e-01, -4.18096259e-02, -2.91381776e-03, 4.60217074e-02, 2.49441639e-02, -8.14060718e-02, 5.01608625e-02], [ 4.88011539e-03, -2.66560793e-01, 2.54576653e-01, -3.27401936e-01, -3.45417373e-02, 1.24850854e-01, 1.21500097e-01, 1.21178739e-02, -1.16022766e-01, 5.99885806e-02], [-2.56981283e-01, -1.44573420e-01, 1.04371175e-01, -4.21217978e-01, -1.59503579e-01, 8.96546319e-02, 3.06324773e-02, -1.53720096e-01, 8.64168853e-02, 7.25626722e-02], [-2.04685777e-01, -3.00617278e-01, 1.50979608e-01, -2.39197582e-01, -1.74976379e-01, 1.05926178e-01, 1.34055391e-01, -5.36262542e-02, -3.25728133e-02, 2.01152653e-01], [-1.50551647e-01, -3.61040205e-01, 1.20477870e-01, -1.80131868e-01, -8.18004832e-02, -1.01963282e-02, 1.33314073e-01, -1.33230805e-01, -1.08623460e-01, 2.06317723e-01], [-2.30869636e-01, -2.42820457e-01, 2.41334766e-01, -2.14613199e-01, -7.08356574e-02, 9.27197039e-02, 4.01981547e-03, -1.17412582e-01, -4.79239784e-03, -7.72589743e-02], [-8.84397775e-02, -1.29078090e-01, 1.07740365e-01, -1.63140357e-01, -4.10152525e-02, 5.49239926e-02, 4.32266295e-02, -1.57965105e-02, -9.11063626e-02, 6.61480129e-02], [-6.96447417e-02, -2.34083712e-01, 1.06215045e-01, -2.20578864e-01, -1.90385461e-01, -4.27494422e-02, 2.48595417e-01, -2.44048268e-01, 1.48005933e-02, 1.93281606e-01], [-1.07321113e-01, -8.97436291e-02, 7.78742731e-02, -1.70489565e-01, -2.13539917e-02, 3.00860796e-02, 3.55869122e-02, -1.04250573e-02, -2.48552170e-02, -4.15959209e-03], [-1.08606644e-01, -1.71337932e-01, 4.59231138e-02, -1.43676192e-01, -1.82187557e-03, 1.26099452e-01, 9.20387208e-02, -4.27118205e-02, -1.31104678e-01, 1.09454885e-01], [ 3.43690068e-03, -1.72197163e-01, 2.18428791e-01, -2.35832453e-01, 5.10253012e-04, 
8.90633315e-02, 1.08800516e-01, -5.93448579e-02, -1.23274222e-01, 1.01912856e-01], [ 5.06044403e-02, -1.32788047e-01, 1.30340025e-01, -1.84194073e-01, -1.05169535e-01, 4.61454317e-02, 1.22318804e-01, 5.99838533e-02, -8.22765380e-02, 1.24655224e-01], [-1.66389197e-02, -1.59686506e-01, 2.11619198e-01, 3.12151276e-02, 5.85951060e-02, -9.21339169e-02, 8.64084736e-02, -2.78207473e-02, -9.98123661e-02, 1.67366728e-01], [-1.96241885e-01, -1.59708947e-01, -1.02402270e-02, -2.04398319e-01, -1.56408072e-01, 1.53439976e-02, -5.43324277e-03, -7.84794912e-02, -7.53903389e-03, 1.29824445e-01], [-3.42884958e-02, -1.97664559e-01, 8.04989636e-02, -7.73391724e-02, -5.32783344e-02, -3.41160446e-02, 1.41383514e-01, -5.25050722e-02, -7.55820721e-02, 2.40247995e-01], [-2.44099833e-03, -2.29914010e-01, 2.82605529e-01, -2.79507041e-01, 8.37497786e-03, 1.21161163e-01, 1.53830886e-01, 1.00024045e-02, -2.20579281e-01, 6.34054318e-02], [-1.37659907e-01, -2.35353991e-01, 1.32480711e-01, -2.27872014e-01, 2.50082836e-02, 3.37644294e-02, 1.26229554e-01, -2.06383504e-03, -4.54712920e-02, 1.63978681e-01], [-1.21992707e-01, -2.60297537e-01, 2.51864269e-02, -1.19379327e-01, -6.49414659e-02, 1.57263018e-02, 1.56158909e-01, -5.03900386e-02, -7.10312426e-02, 1.60762340e-01], [-1.35451257e-01, -2.06156701e-01, 1.90098822e-01, -2.72125185e-01, -1.27937704e-01, 1.76560059e-02, 1.05175495e-01, -8.65008608e-02, -9.35260206e-03, 1.94259420e-01]], dtype=float32)> } } 2023-07-27 06:34:13.680190: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Calling the restored function performs only a forward pass on the saved model (like tf.keras.Model.predict). What if you want to continue training the loaded function, or embed it into a bigger model? A common practice is to wrap the loaded object in a Keras layer. TF Hub provides hub.KerasLayer for exactly this purpose, as shown here:
import tensorflow_hub as hub

def build_model(loaded):
    x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
    # Wrap the loaded SavedModel object in a KerasLayer
    keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
    model = tf.keras.Model(x, keras_layer)
    return model
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    loaded = tf.saved_model.load(saved_model_path)
    model = build_model(loaded)

    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=[tf.metrics.SparseCategoricalAccuracy()])
    model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
2023-07-27 06:34:14.424425: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:551] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/2
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
235/235 [==============================] - 4s 7ms/step - loss: 0.3245 - sparse_categorical_accuracy: 0.9092
Epoch 2/2
235/235 [==============================] - 2s 7ms/step - loss: 0.0955 - sparse_categorical_accuracy: 0.9726
In the example above, TensorFlow Hub's hub.KerasLayer wraps the object returned by tf.saved_model.load in a Keras layer, which is then used to build another model. This is very useful for transfer learning.
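Continuing to train a loaded object is possible because variables restored by tf.saved_model.load are still trainable variables, and gradients flow through restored functions. The following is a minimal sketch of that idea that uses neither TF Hub nor Keras. It saves a toy tf.Module, reloads it, and applies one manual gradient-descent step to the restored variable. The Scaler class, its w variable, and the learning rate are hypothetical choices made only for illustration.

```python
import os
import tempfile

import tensorflow as tf


class Scaler(tf.Module):
    """Hypothetical toy module: multiplies its input by one trainable scalar."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.w


export_dir = os.path.join(tempfile.mkdtemp(), 'scaler')
tf.saved_model.save(Scaler(), export_dir)

# The restored object exposes both the saved variable and the traced function.
restored = tf.saved_model.load(export_dir)
x = tf.constant([1.0, 2.0, 3.0])

# One manual gradient-descent step: the gradient flows through the restored
# function back to the restored variable, so the variable can keep training.
learning_rate = 0.1  # hypothetical value
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(restored(x))  # loss = sum(x * w) = 12.0
grad = tape.gradient(loss, restored.w)  # d(loss)/dw = sum(x) = 6.0
restored.w.assign_sub(learning_rate * grad)

print(float(restored.w.numpy()))  # approximately 1.4 (2.0 - 0.1 * 6.0)
```

In a real transfer-learning setup, hub.KerasLayer handles this bookkeeping for you inside a Keras model.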
Which API should I use?
For saving, if you are working with a Keras model, use the Keras Model.save API unless you need the additional control allowed by the low-level API. If what you are saving is not a Keras model, then the lower-level API, tf.saved_model.save, is your only choice.
For loading, your API choice depends on what kind of object you want to get back. If you cannot (or do not want to) get a Keras model, use tf.saved_model.load. Otherwise, use tf.keras.models.load_model. Note that you can get a Keras model back only if you saved a Keras model.
It is possible to mix and match the APIs. For example, you can save a Keras model with Model.save and load it back as a non-Keras object with the low-level API, tf.saved_model.load.
model = get_model()

# Saving the model using Keras `Model.save`
model.save(saved_model_path)

another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
    loaded = tf.saved_model.load(saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
Saving/Loading from a local device
When saving to or loading from a local I/O device while training on remote devices (for example, when using a Cloud TPU), you must use the experimental_io_device option in tf.saved_model.SaveOptions and tf.saved_model.LoadOptions to set the I/O device to localhost. For example:
model = get_model()

# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)

# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
    load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
    loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
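The same options work with the low-level API, because tf.saved_model.save and tf.saved_model.load both accept an options argument. The sketch below shows only where SaveOptions and LoadOptions plug in. It assumes a hypothetical non-Keras tf.Module named Doubler and uses a temporary export directory in place of /tmp/tf_save.

```python
import os
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical non-Keras module used only to illustrate the I/O options."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.factor


export_dir = os.path.join(tempfile.mkdtemp(), 'doubler')

# Route the file I/O for saving through the local host.
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
tf.saved_model.save(Doubler(), export_dir, options=save_options)

# Route the file I/O for loading through the local host as well.
load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
restored = tf.saved_model.load(export_dir, options=load_options)

print(restored(tf.constant([1.0, 2.0, 3.0])).numpy())  # [2. 4. 6.]
```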
Caveats
One special case is when you create a Keras model in a way that leaves its input shape undefined, such as by subclassing tf.keras.Model, and then save it before training. For example:
class SubclassedModel(tf.keras.Model):
    """Example model defined by subclassing `tf.keras.Model`."""

    output_name = 'output_layer'

    def __init__(self):
        super().__init__()
        self._dense_layer = tf.keras.layers.Dense(
            5, dtype=tf.dtypes.float32, name=self.output_name)

    def call(self, inputs):
        return self._dense_layer(inputs)

my_model = SubclassedModel()
try:
    my_model.save(saved_model_path)
except ValueError as e:
    print(f'{type(e).__name__}: ', *e.args)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7fb814393280>, because it is not built.
ValueError: Model <__main__.SubclassedModel object at 0x7fb814393280> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.
A SavedModel saves the tf.types.experimental.ConcreteFunction objects generated when you trace a tf.function (check When is a Function tracing? in the Introduction to graphs and tf.function guide to learn more). If you get a ValueError like this, it is because Model.save was not able to find or create a traced ConcreteFunction.
tf.saved_model.save(my_model, saved_model_path)
x = tf.saved_model.load(saved_model_path)
x.signatures
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7fb814393280>, because it is not built.
WARNING:tensorflow:Skipping full serialization of Keras layer <keras.src.layers.core.dense.Dense object at 0x7fb8143800a0>, because it is not built.
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
_SignatureMap({})
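The empty _SignatureMap above shows that no traced function was exported as a serving signature. Outside Keras, one way to guarantee that a traced ConcreteFunction exists at save time is to give the tf.function an input_signature. The sketch below assumes a hypothetical tf.Module named Affine and exports its traced __call__ explicitly under the serving_default key. Restored signature functions are called with keyword arguments and return a dictionary of outputs. A single unnamed output is keyed output_0.

```python
import os
import tempfile

import tensorflow as tf


class Affine(tf.Module):
    """Hypothetical module whose forward pass declares its input shape up front."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(3.0)
        self.b = tf.Variable(1.0)

    # The input_signature lets tf.function trace a ConcreteFunction
    # without the module ever being called on real data.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32, name='x')])
    def __call__(self, x):
        return self.w * x + self.b


module = Affine()
export_dir = os.path.join(tempfile.mkdtemp(), 'affine')

# Export the traced ConcreteFunction explicitly as the default serving signature.
tf.saved_model.save(
    module, export_dir,
    signatures={'serving_default': module.__call__.get_concrete_function()})

restored = tf.saved_model.load(export_dir)
print(list(restored.signatures.keys()))  # ['serving_default']

# Signature functions take keyword arguments and return a dict of tensors.
outputs = restored.signatures['serving_default'](x=tf.constant([0.0, 1.0]))
print(outputs['output_0'].numpy())  # [1. 4.]
```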
Usually, the model's forward pass (the call method) is traced automatically when the model is called for the first time, often via the Keras Model.fit method. The Keras Sequential and Functional APIs can also generate a ConcreteFunction if you set the input shape, for example by making the first layer a tf.keras.layers.InputLayer, or by passing the input_shape keyword argument to a first layer of another type.
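A minimal sketch of the input-shape point, with illustrative layer sizes: declaring the shape with `tf.keras.Input` (which adds an InputLayer to a Sequential model) lets Keras build the model immediately, while the same model without a declared shape stays unbuilt until you call `build` or feed it data. This checks `Model.built` rather than `save_spec`, so treat it as a proxy for "Keras knows the input shape", not a direct check of traced functions.

```python
import tensorflow as tf

# Input shape declared up front: Keras can build the model right away.
declared = tf.keras.Sequential([
    tf.keras.Input(shape=(5,)),
    tf.keras.layers.Dense(5),
])
declared_built = declared.built

# Same layer without an input shape: nothing can be built yet.
undeclared = tf.keras.Sequential([tf.keras.layers.Dense(5)])
undeclared_before = undeclared.built

# Supplying the shape later (batch dimension left as None) builds the model.
undeclared.build(input_shape=(None, 5))
undeclared_after = undeclared.built

print(declared_built, undeclared_before, undeclared_after)  # True False True
```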
To verify whether your model has any traced ConcreteFunctions, check whether Model.save_spec() returns None:
print(my_model.save_spec() is None)
True
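Besides Model.fit, calling a subclassed model once on a batch of data also triggers the first forward pass. In this sketch, `TinyModel` is a hypothetical stand-in for the `SubclassedModel` used above, and the zero-filled batch exists only to fix the input shape.

```python
import tensorflow as tf

class TinyModel(tf.keras.Model):
    """Hypothetical stand-in for a subclassed model with one Dense layer."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(5)

    def call(self, inputs):
        return self.dense(inputs)

model = TinyModel()
built_before_call = model.built  # no input shape is known yet

# One forward pass through Model.__call__ on a dummy (batch, 5) input builds it.
outputs = model(tf.zeros((2, 5)))
built_after_call = model.built

print(built_before_call, built_after_call)  # False True
print(outputs.shape)  # (2, 5)
```

Note that the forward pass goes through `model(inputs)`, which is what the error message above recommends, rather than `model.call(inputs)`.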
Let's use tf.keras.Model.fit to train the model, and notice that save_spec gets defined and model saving works:
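# Small synthetic dataset: each example maps a length-5 vector to itself.
BATCH_SIZE_PER_REPLICA = 4
# The global batch is split evenly across the replicas of mirrored_strategy.
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync
dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
).repeat(dataset_size).batch(BATCH_SIZE)
# Training runs the forward pass, which traces the model's call method.
my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)
# save_spec is now defined, so Model.save can find a traced ConcreteFunction.
print(my_model.save_spec() is None)
my_model.save(saved_model_path)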
Epoch 1/2
7/7 [==============================] - 1s 3ms/step - loss: 6.6204
Epoch 2/2
7/7 [==============================] - 0s 2ms/step - loss: 6.2650
False
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fb81439da30>, 140428588826032), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fb81433bf40>, 140428588823152), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fb81439da30>, 140428588826032), {}).
INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7fb81433bf40>, 140428588823152), {}).
INFO:tensorflow:Assets written to: /tmp/tf_save/assets
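The `7/7` in the progress bar follows from the batching arithmetic in the code above. A small sketch, assuming the four replicas listed in the MirroredStrategy log at the start of this tutorial: the global batch is 4 × 4 = 16 examples, so 100 examples produce ⌈100 / 16⌉ = 7 batches per epoch, the last one partial. The helper name `steps_per_epoch` is hypothetical.

```python
import math

def steps_per_epoch(dataset_size: int, per_replica_batch: int, num_replicas: int) -> int:
    """Hypothetical helper: batches per epoch when the final partial batch is kept."""
    global_batch = per_replica_batch * num_replicas
    # Dataset.batch keeps the last, smaller batch by default (drop_remainder=False).
    return math.ceil(dataset_size / global_batch)

print(steps_per_epoch(100, 4, 4))  # 7 (four GPUs, as in the log above)
print(steps_per_epoch(100, 4, 1))  # 25 (a single device)
```

On a machine with a different number of devices, the step count in the progress bar changes accordingly, even though the code is unchanged.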