Train and serve a TensorFlow model with TensorFlow Serving

This guide trains a neural network model to classify images of clothing, like sneakers and shirts, saves the trained model, and then serves it with TensorFlow Serving. The focus is on TensorFlow Serving rather than on modeling and training in TensorFlow; for a complete example that focuses on modeling and training, see the Basic Classification example.

This guide uses tf.keras, a high-level API to build and train models in TensorFlow.

import sys

# Confirm that we're using Python 3
assert sys.version_info.major == 3, 'Oops, not running Python 3. Use Runtime > Change runtime type'
print("Installing dependencies for Colab environment")
!pip install -Uq grpcio==1.26.0

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
import subprocess

print('TensorFlow version: {}'.format(tf.__version__))

Create your model

Import the Fashion MNIST dataset

This guide uses the Fashion MNIST dataset, which contains 70,000 grayscale images in 10 categories. The images show individual articles of clothing at low resolution (28 by 28 pixels), as seen here:

Figure 1. Fashion-MNIST samples (by Zalando, MIT License).

Fashion MNIST is intended as a drop-in replacement for the classic MNIST dataset, often used as the "Hello, World" of machine learning programs for computer vision. You can access Fashion MNIST directly from TensorFlow; just import and load the data.

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# scale the values to 0.0 to 1.0
train_images = train_images / 255.0
test_images = test_images / 255.0

# reshape for feeding into the model
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

print('\ntrain_images.shape: {}, of {}'.format(train_images.shape, train_images.dtype))
print('test_images.shape: {}, of {}'.format(test_images.shape, test_images.dtype))
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
29515/29515 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26421880/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
5148/5148 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4422102/4422102 [==============================] - 0s 0us/step

train_images.shape: (60000, 28, 28, 1), of float64
test_images.shape: (10000, 28, 28, 1), of float64

Train and evaluate your model

Let's use the simplest possible CNN, since we're not focused on the modeling part.

model = keras.Sequential([
  keras.layers.Conv2D(input_shape=(28,28,1), filters=8, kernel_size=3, 
                      strides=2, activation='relu', name='Conv1'),
  keras.layers.Flatten(),
  keras.layers.Dense(10, name='Dense')
])
model.summary()

testing = False
epochs = 5

model.compile(optimizer='adam', 
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.fit(train_images, train_labels, epochs=epochs)

test_loss, test_acc = model.evaluate(test_images, test_labels)
print('\nTest accuracy: {}'.format(test_acc))
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 Conv1 (Conv2D)              (None, 13, 13, 8)         80        
                                                                 
 flatten (Flatten)           (None, 1352)              0         
                                                                 
 Dense (Dense)               (None, 10)                13530     
                                                                 
=================================================================
Total params: 13,610
Trainable params: 13,610
Non-trainable params: 0
_________________________________________________________________
Epoch 1/5
1875/1875 [==============================] - 11s 3ms/step - loss: 0.5487 - sparse_categorical_accuracy: 0.8076
Epoch 2/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.4082 - sparse_categorical_accuracy: 0.8560
Epoch 3/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.3688 - sparse_categorical_accuracy: 0.8696
Epoch 4/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.3475 - sparse_categorical_accuracy: 0.8769
Epoch 5/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.3309 - sparse_categorical_accuracy: 0.8826
313/313 [==============================] - 1s 2ms/step - loss: 0.3662 - sparse_categorical_accuracy: 0.8713

Test accuracy: 0.8712999820709229
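The output shapes and parameter counts in the model summary above follow from simple arithmetic. The sketch below reproduces them in plain Python. It assumes the Keras Conv2D default padding='valid' (the model above does not set padding) and one bias per filter and per Dense unit; the helper names are hypothetical, not part of Keras.

```python
def conv2d_output_size(input_size: int, kernel_size: int, stride: int) -> int:
    """Spatial output size of a 'valid' (unpadded) convolution."""
    return (input_size - kernel_size) // stride + 1


def conv2d_params(kernel_size: int, in_channels: int, filters: int) -> int:
    """k * k * in_channels weights per filter, plus one bias per filter."""
    return kernel_size * kernel_size * in_channels * filters + filters


def dense_params(in_features: int, units: int) -> int:
    """One weight per input/output pair, plus one bias per unit."""
    return in_features * units + units


# Values taken from the model definition above.
side = conv2d_output_size(28, kernel_size=3, stride=2)    # 13
conv_params = conv2d_params(3, in_channels=1, filters=8)  # 80
flat = side * side * 8                                    # 1352 (Flatten output)
dense = dense_params(flat, 10)                            # 13530
total = conv_params + dense                               # 13610

print(side, conv_params, flat, dense, total)  # 13 80 1352 13530 13610
```

Running the same arithmetic with a stride of 1 gives a 26x26x8 feature map and a Dense layer with 54,090 parameters, which shows how the stride of 2 keeps the Dense layer small.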

Save your model

To load our trained model into TensorFlow Serving, we first need to save it in SavedModel format. This creates a protobuf file (saved_model.pb) in a well-defined directory hierarchy that includes a version number. TensorFlow Serving allows us to select which version of a model, or "servable", we want to use when we make inference requests. Each version is exported to a different subdirectory under the given path.

# Save the model in SavedModel format under MODEL_DIR/<version>.
# The signature definition is derived from the model's input and output tensors
# and stored under the default serving key (serving_default).
import tempfile

MODEL_DIR = tempfile.gettempdir()
version = 1
export_path = os.path.join(MODEL_DIR, str(version))
print('export_path = {}\n'.format(export_path))

tf.keras.models.save_model(
    model,
    export_path,
    overwrite=True,
    include_optimizer=True,
    save_format=None,
    signatures=None,
    options=None
)

print('\nSaved model:')
!ls -l {export_path}
export_path = /tmpfs/tmp/1
WARNING:absl:Function `_wrapped_model` contains input name(s) Conv1_input with unsupported characters which will be renamed to conv1_input in the SavedModel.
WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op while saving (showing 1 of 1). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmpfs/tmp/1/assets
Saved model:
total 112
drwxr-xr-x 2 kbuilder kbuilder  4096 Jul 28 11:22 assets
-rw-rw-r-- 1 kbuilder kbuilder    57 Jul 28 11:22 fingerprint.pb
-rw-rw-r-- 1 kbuilder kbuilder  8757 Jul 28 11:22 keras_metadata.pb
-rw-rw-r-- 1 kbuilder kbuilder 89140 Jul 28 11:22 saved_model.pb
drwxr-xr-x 2 kbuilder kbuilder  4096 Jul 28 11:22 variables
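By default, TensorFlow Serving serves only the latest version of a model, meaning the subdirectory under the model base path with the highest integer name. The sketch below mimics that selection in plain Python so you can check which version a server pointed at MODEL_DIR would pick. latest_version is a hypothetical helper, not part of TensorFlow Serving, and it only inspects directory names; it does not check that each directory holds a valid SavedModel.

```python
import os
import tempfile
from typing import Optional


def latest_version(base_path: str) -> Optional[int]:
    """Return the highest integer-named subdirectory of base_path, or None."""
    versions = [
        int(name)
        for name in os.listdir(base_path)
        if name.isdecimal() and os.path.isdir(os.path.join(base_path, name))
    ]
    return max(versions) if versions else None


# Usage with a throwaway directory standing in for MODEL_DIR.
with tempfile.TemporaryDirectory() as base:
    for v in ("1", "2", "10"):
        os.makedirs(os.path.join(base, v))
    os.makedirs(os.path.join(base, "assets"))  # non-numeric names are ignored
    print(latest_version(base))  # 10
```

Comparing integers rather than strings matters here: as strings, "2" sorts after "10", so a string comparison would pick the wrong version.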

Examine your saved model

We'll use the command-line utility saved_model_cli to look at the MetaGraphDefs (the models) and SignatureDefs (the methods you can call) in our SavedModel. See the discussion of the SavedModel CLI in the TensorFlow Guide for details.

!saved_model_cli show --dir {export_path} --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['Conv1_input'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 28, 28, 1)
        name: serving_default_Conv1_input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['Dense'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 10)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict
The MetaGraph with tag set ['serve'] contains the following ops: {'NoOp', 'MatMul', 'SaveV2', 'MergeV2Checkpoints', 'RestoreV2', 'StaticRegexFullMatch', 'Relu', 'ShardedFilename', 'ReadVariableOp', 'StringJoin', 'Identity', 'Const', 'Reshape', 'Select', 'VarHandleOp', 'StatefulPartitionedCall', 'DisableCopyOnRead', 'AssignVariableOp', 'BiasAdd', 'Placeholder', 'Conv2D', 'Pack'}

That tells us a lot about our model! In this case we just trained the model, so we already know its inputs and outputs, but if we didn't, this would be important information. It doesn't tell us everything (for example, that this is grayscale image data), but it's a great start.
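The serving_default signature reported by saved_model_cli expects a float tensor of shape (-1, 28, 28, 1), where -1 means any batch size. Before sending a request, you can check that a batch matches that shape. The sketch below does this in plain Python for nested lists such as those produced by test_images[0:3].tolist(); nested_shape and matches_signature are hypothetical helpers, and they assume rectangular (non-ragged) lists because they only follow the first element at each level.

```python
from typing import Any, Sequence, Tuple


def nested_shape(value: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list, following the first element at each level."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        if not value:
            break
        value = value[0]
    return tuple(shape)


def matches_signature(batch: list, signature_shape: Sequence[int]) -> bool:
    """True if batch matches signature_shape, treating -1 as 'any size'."""
    actual = nested_shape(batch)
    if len(actual) != len(signature_shape):
        return False
    return all(want in (-1, got) for want, got in zip(signature_shape, actual))


# Shape reported by saved_model_cli for the serving_default input.
SIGNATURE_SHAPE = (-1, 28, 28, 1)

# A batch of three all-zero "images" with the same layout as the tutorial's data.
fake_batch = [[[[0.0] for _ in range(28)] for _ in range(28)] for _ in range(3)]
print(nested_shape(fake_batch))                           # (3, 28, 28, 1)
print(matches_signature(fake_batch, SIGNATURE_SHAPE))     # True
print(matches_signature(fake_batch[0], SIGNATURE_SHAPE))  # False: no batch dimension
```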

Serve your model with TensorFlow Serving

Add TensorFlow Serving distribution URI as a package source:

We're preparing to install TensorFlow Serving using APT, since this Colab runs in a Debian-based environment (Ubuntu, as the apt output below shows). We'll add the tensorflow-model-server package to the list of packages that APT knows about. On Colab we're running as root; elsewhere the commands are prefixed with sudo.

import sys
# We need sudo prefix if not on a Google Colab.
if 'google.colab' not in sys.modules:
  SUDO_IF_NEEDED = 'sudo'
else:
  SUDO_IF_NEEDED = ''
# From a regular command line you would run the commands below, which include
# [arch=amd64] and sudo. This cell omits [arch=amd64] and uses sudo only when needed:
# echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list && \
# curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -

!echo "deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | {SUDO_IF_NEEDED} tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | {SUDO_IF_NEEDED} apt-key add -
!{SUDO_IF_NEEDED} apt update
deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  2943  100  2943    0     0  39240      0 --:--:-- --:--:-- --:--:-- 39240
OK
Hit:1 http://us-west1.gce.archive.ubuntu.com/ubuntu focal InRelease
Hit:2 http://us-west1.gce.archive.ubuntu.com/ubuntu focal-updates InRelease
Get:3 http://us-west1.gce.archive.ubuntu.com/ubuntu focal-backports InRelease [108 kB]
Hit:4 https://download.docker.com/linux/ubuntu focal InRelease
Hit:5 https://nvidia.github.io/libnvidia-container/stable/ubuntu18.04/amd64  InRelease
Hit:6 https://nvidia.github.io/nvidia-container-runtime/stable/ubuntu18.04/amd64  InRelease
Hit:7 https://nvidia.github.io/nvidia-docker/ubuntu18.04/amd64  InRelease
Get:8 http://storage.googleapis.com/tensorflow-serving-apt stable InRelease [3026 B]
Hit:9 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64  InRelease
Hit:10 http://security.ubuntu.com/ubuntu focal-security InRelease
Hit:12 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu focal InRelease
Hit:11 https://apt.llvm.org/focal llvm-toolchain-focal-16 InRelease
Hit:13 http://ppa.launchpad.net/longsleep/golang-backports/ubuntu focal InRelease
Hit:14 http://ppa.launchpad.net/openjdk-r/ppa/ubuntu focal InRelease
Err:8 http://storage.googleapis.com/tensorflow-serving-apt stable InRelease
  The following signatures were invalid: EXPKEYSIG 544B7F63BF9E4D5F Tensorflow Serving Developer (Tensorflow Serving APT repository key) <tensorflow-serving-dev@googlegroups.com>

W: GPG error: http://storage.googleapis.com/tensorflow-serving-apt stable InRelease: The following signatures were invalid: EXPKEYSIG 544B7F63BF9E4D5F Tensorflow Serving Developer (Tensorflow Serving APT repository key) <tensorflow-serving-dev@googlegroups.com>
E: The repository 'http://storage.googleapis.com/tensorflow-serving-apt stable InRelease' is not signed.
N: Updating from such a repository can't be done securely, and is therefore disabled by default.
N: See apt-secure(8) manpage for repository creation and user configuration details.

Install TensorFlow Serving

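Because of the version workaround explained in the comments below, installation takes three commands: download the TensorFlow Model Server 2.8.0 package, install it, and install the matching tensorflow-serving-api Python package.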

# TODO: Use the latest model server version when colab supports it.
#!{SUDO_IF_NEEDED} apt-get install tensorflow-model-server
# We need to install TensorFlow Model Server 2.8 instead of the latest version.
# TensorFlow Serving >2.9.0 requires `GLIBC_2.29` and `GLIBCXX_3.4.26`. The Colab environment doesn't currently support that version of `GLIBC`, so the workaround is to install TensorFlow Serving `2.8.0`.
!wget 'http://storage.googleapis.com/tensorflow-serving-apt/pool/tensorflow-model-server-2.8.0/t/tensorflow-model-server/tensorflow-model-server_2.8.0_all.deb'
# dpkg needs superuser privileges, so prefix it with sudo when not on Colab.
!{SUDO_IF_NEEDED} dpkg -i tensorflow-model-server_2.8.0_all.deb
!pip3 install tensorflow-serving-api==2.8.0
--2023-07-28 11:22:42--  http://storage.googleapis.com/tensorflow-serving-apt/pool/tensorflow-model-server-2.8.0/t/tensorflow-model-server/tensorflow-model-server_2.8.0_all.deb
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.199.128, 173.194.203.128, 173.194.202.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.199.128|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 340152790 (324M) [application/x-debian-package]
Saving to: ‘tensorflow-model-server_2.8.0_all.deb’

tensorflow-model-se 100%[===================>] 324.39M  42.5MB/s    in 9.1s    

2023-07-28 11:22:51 (35.7 MB/s) - ‘tensorflow-model-server_2.8.0_all.deb’ saved [340152790/340152790]

dpkg: error: requested operation requires superuser privilege
Collecting tensorflow-serving-api==2.8.0
  Downloading tensorflow_serving_api-2.8.0-py2.py3-none-any.whl (37 kB)
Requirement already satisfied: grpcio<2,>=1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-serving-api==2.8.0) (1.26.0)
Requirement already satisfied: protobuf>=3.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-serving-api==2.8.0) (3.20.3)
Requirement already satisfied: tensorflow<3,>=2.8.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow-serving-api==2.8.0) (2.12.1)
Requirement already satisfied: six>=1.5.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from grpcio<2,>=1.0->tensorflow-serving-api==2.8.0) (1.16.0)
Requirement already satisfied: absl-py>=1.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (1.4.0)
Requirement already satisfied: astunparse>=1.6.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (1.6.3)
Requirement already satisfied: flatbuffers>=2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (23.5.26)
Requirement already satisfied: gast<=0.4.0,>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (0.4.0)
Requirement already satisfied: google-pasta>=0.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (0.2.0)
Requirement already satisfied: h5py>=2.9.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (3.9.0)
Requirement already satisfied: jax>=0.3.15 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (0.4.14)
Requirement already satisfied: keras<2.13,>=2.12.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (2.12.0)
Requirement already satisfied: libclang>=13.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (16.0.6)
Requirement already satisfied: numpy<=1.24.3,>=1.22 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (1.22.4)
Requirement already satisfied: opt-einsum>=2.3.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (3.3.0)
Requirement already satisfied: packaging in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (20.9)
Requirement already satisfied: setuptools in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (68.0.0)
Requirement already satisfied: tensorboard<2.13,>=2.12 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (2.12.3)
Requirement already satisfied: tensorflow-estimator<2.13,>=2.12.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (2.12.0)
Requirement already satisfied: termcolor>=1.1.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (2.3.0)
Requirement already satisfied: typing-extensions<4.6.0,>=3.6.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (4.5.0)
Requirement already satisfied: wrapt<1.15,>=1.11.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (1.14.1)
Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (0.32.0)
Requirement already satisfied: wheel<1.0,>=0.23.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from astunparse>=1.6.0->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (0.41.0)
Requirement already satisfied: ml-dtypes>=0.2.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax>=0.3.15->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (0.2.0)
Requirement already satisfied: scipy>=1.7 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax>=0.3.15->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (1.11.1)
Requirement already satisfied: importlib-metadata>=4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from jax>=0.3.15->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (6.8.0)
Collecting grpcio<2,>=1.0 (from tensorflow-serving-api==2.8.0)
  Obtaining dependency information for grpcio<2,>=1.0 from https://files.pythonhosted.org/packages/f5/f6/57fbd39af17aaae321109411ef2faf121768473ebc1bbf3694b06d3282c8/grpcio-1.56.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Downloading grpcio-1.56.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.0 kB)
Requirement already satisfied: google-auth<3,>=1.6.3 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (1.35.0)
Requirement already satisfied: google-auth-oauthlib<1.1,>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (0.5.3)
Requirement already satisfied: markdown>=2.6.8 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (3.4.4)
Requirement already satisfied: requests<3,>=2.21.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (2.31.0)
Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (0.7.1)
Requirement already satisfied: werkzeug>=1.0.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (2.3.6)
Requirement already satisfied: pyparsing>=2.0.2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from packaging->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (3.1.0)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (4.2.4)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (0.3.0)
Requirement already satisfied: rsa<5,>=3.1.4 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (4.9)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from google-auth-oauthlib<1.1,>=0.5->tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (1.3.1)
Requirement already satisfied: zipp>=0.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from importlib-metadata>=4.6->jax>=0.3.15->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (3.16.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (3.2.0)
Requirement already satisfied: idna<4,>=2.5 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (3.4)
Requirement already satisfied: urllib3<3,>=1.21.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests<3,>=2.21.0->tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (2023.7.22)
Requirement already satisfied: MarkupSafe>=2.1.1 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from werkzeug>=1.0.1->tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (2.1.3)
Requirement already satisfied: pyasn1<0.6.0,>=0.4.6 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (0.5.0)
Requirement already satisfied: oauthlib>=3.0.0 in /tmpfs/src/tf_docs_env/lib/python3.9/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<1.1,>=0.5->tensorboard<2.13,>=2.12->tensorflow<3,>=2.8.0->tensorflow-serving-api==2.8.0) (3.2.2)
Downloading grpcio-1.56.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.2 MB)
Installing collected packages: grpcio, tensorflow-serving-api
  Attempting uninstall: grpcio
    Found existing installation: grpcio 1.26.0
    Uninstalling grpcio-1.26.0:
      Successfully uninstalled grpcio-1.26.0
  Attempting uninstall: tensorflow-serving-api
    Found existing installation: tensorflow-serving-api 2.12.2
    Uninstalling tensorflow-serving-api-2.12.2:
      Successfully uninstalled tensorflow-serving-api-2.12.2
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tfx 1.13.0 requires tensorflow-serving-api!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,<3,>=1.15, but you have tensorflow-serving-api 2.8.0 which is incompatible.
tfx-bsl 1.13.0 requires tensorflow-serving-api!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,<3,>=1.15, but you have tensorflow-serving-api 2.8.0 which is incompatible.
Successfully installed grpcio-1.56.2 tensorflow-serving-api-2.8.0

Start running TensorFlow Serving

This is where we start running TensorFlow Serving and load our model. Once it loads, we can start making inference requests using REST. There are some important parameters:

  • rest_api_port: The port that you'll use for REST requests.
  • model_name: You'll use this in the URL of REST requests. It can be anything.
  • model_base_path: This is the path to the directory where you've saved your model.
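If you prefer to launch the server from Python rather than from a shell cell, the three parameters above map directly onto the command-line flags used in the cell that follows. The sketch below assembles that argument list and shows how it could be started with subprocess.Popen. build_server_command and start_server are hypothetical helpers, and actually starting the server assumes tensorflow_model_server is installed and on your PATH.

```python
import subprocess
import tempfile
from typing import List


def build_server_command(model_name: str, model_base_path: str,
                         rest_api_port: int = 8501) -> List[str]:
    """Argument list for tensorflow_model_server, using the flags shown in this guide."""
    return [
        "tensorflow_model_server",
        f"--rest_api_port={rest_api_port}",
        f"--model_name={model_name}",
        f"--model_base_path={model_base_path}",
    ]


def start_server(cmd: List[str], log_path: str = "server.log") -> subprocess.Popen:
    """Launch the server in the background; stdout and stderr both go to log_path."""
    with open(log_path, "w") as log_file:
        # The child process keeps its own copy of the file descriptor,
        # so closing our handle after Popen returns is safe.
        return subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)


# MODEL_DIR in this guide is tempfile.gettempdir(); mirror that here.
cmd = build_server_command("fashion_model", tempfile.gettempdir())
print(" ".join(cmd))
# server = start_server(cmd)  # requires tensorflow_model_server on PATH
```

Passing a list to Popen avoids shell quoting issues with paths. If the binary is missing, Popen raises FileNotFoundError immediately rather than writing an error into the log file.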
os.environ["MODEL_DIR"] = MODEL_DIR

%%bash --bg
nohup tensorflow_model_server \
  --rest_api_port=8501 \
  --model_name=fashion_model \
  --model_base_path="${MODEL_DIR}" >server.log 2>&1

!tail server.log
nohup: failed to run command 'tensorflow_model_server': No such file or directory
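Loading a model takes a moment, so a request sent right after launch can fail. TensorFlow Serving's REST API also exposes a model status endpoint, GET /v1/models/<model_name>, whose JSON response lists each version with a state such as LOADING or AVAILABLE. The sketch below polls that endpoint with the standard library until a version is AVAILABLE or a timeout expires. wait_until_available and any_version_available are hypothetical helpers, and the default host and port assume the --rest_api_port=8501 setting used above.

```python
import json
import time
import urllib.request


def any_version_available(status: dict) -> bool:
    """True if a model status response reports at least one AVAILABLE version."""
    return any(v.get("state") == "AVAILABLE"
               for v in status.get("model_version_status", []))


def wait_until_available(model_name: str, host: str = "localhost",
                         port: int = 8501, timeout_s: float = 30.0,
                         interval_s: float = 1.0) -> bool:
    """Poll the model status endpoint until a version is AVAILABLE or timeout_s elapses."""
    url = f"http://{host}:{port}/v1/models/{model_name}"
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=interval_s) as resp:
                if any_version_available(json.load(resp)):
                    return True
        except (OSError, ValueError):
            # URLError/HTTPError (server not up or model not loaded yet),
            # timeouts, and malformed JSON all mean "try again".
            pass
        time.sleep(interval_s)
    return False


# Example status body in the shape the endpoint returns.
sample = {"model_version_status": [
    {"version": "1", "state": "AVAILABLE",
     "status": {"error_code": "OK", "error_message": ""}}]}
print(any_version_available(sample))  # True
```

In this notebook you could call wait_until_available("fashion_model") before sending the requests below; a False result is a cue to check server.log.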

Make a request to your model in TensorFlow Serving

First, let's take a look at a random example from our test data.

def show(idx, title):
  plt.figure()
  plt.imshow(test_images[idx].reshape(28,28))
  plt.axis('off')
  plt.title('\n\n{}'.format(title), fontdict={'size': 16})

import random
rando = random.randint(0,len(test_images)-1)
show(rando, 'An Example Image: {}'.format(class_names[test_labels[rando]]))


Ok, that looks interesting. How hard is that for you to recognize? Now let's create the JSON object for a batch of three inference requests, and see how well our model recognizes things:

import json
data = json.dumps({"signature_name": "serving_default", "instances": test_images[0:3].tolist()})
print('Data: {} ... {}'.format(data[:50], data[len(data)-52:]))
Data: {"signature_name": "serving_default", "instances": ...  [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]]]}
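TensorFlow Serving's REST predict API accepts the request body in two layouts: the row format used above, where "instances" is a list with one entry per example, and a columnar format, where "inputs" holds the whole batch. For a model with a single input, as here, both carry the same list. The "signature_name" key is optional and defaults to serving_default. The sketch below builds either body from plain nested lists; make_predict_body is a hypothetical helper.

```python
import json
from typing import Any, List, Optional


def make_predict_body(batch: List[Any], columnar: bool = False,
                      signature_name: Optional[str] = "serving_default") -> str:
    """JSON body for TensorFlow Serving's REST :predict endpoint."""
    body = {}
    if signature_name is not None:
        body["signature_name"] = signature_name
    # Row format: one entry per example. Columnar format: the whole batch under "inputs".
    body["inputs" if columnar else "instances"] = batch
    return json.dumps(body)


# Two tiny 2x2x1 "images" standing in for test_images[0:2].tolist().
batch = [[[[0.0], [0.5]], [[1.0], [0.25]]],
         [[[0.1], [0.2]], [[0.3], [0.4]]]]
print(make_predict_body(batch)[:60])
print(make_predict_body(batch, columnar=True)[:60])
```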

Make REST requests

Newest version of the servable

We'll send a predict request as a POST to our server's REST endpoint, and pass it three examples. We'll ask our server to give us the latest version of our servable by not specifying a particular version.

# docs_infra: no_execute
!pip install -q requests

import requests
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/fashion_model:predict', data=data, headers=headers)
predictions = json.loads(json_response.text)['predictions']

show(0, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(
  class_names[np.argmax(predictions[0])], np.argmax(predictions[0]), class_names[test_labels[0]], test_labels[0]))
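The predict URL follows the pattern http://<host>:<port>/v1/models/<model_name>[/versions/<version>]:predict; leaving out the version asks for the latest one, as in the request above. As a dependency-free alternative to requests, the sketch below builds that URL and posts the JSON body with urllib.request. predict_url and post_predict are hypothetical helpers, and calling post_predict assumes the server started earlier is running on port 8501.

```python
import json
import urllib.request
from typing import List, Optional


def predict_url(model_name: str, version: Optional[int] = None,
                host: str = "localhost", port: int = 8501) -> str:
    """REST predict endpoint; version=None targets the latest servable version."""
    version_part = "" if version is None else f"/versions/{version}"
    return f"http://{host}:{port}/v1/models/{model_name}{version_part}:predict"


def post_predict(url: str, body: str, timeout_s: float = 10.0) -> List[list]:
    """POST a JSON body and return the 'predictions' field of the response."""
    request = urllib.request.Request(
        url, data=body.encode("utf-8"),
        headers={"content-type": "application/json"}, method="POST")
    with urllib.request.urlopen(request, timeout=timeout_s) as response:
        return json.load(response)["predictions"]


print(predict_url("fashion_model"))
# http://localhost:8501/v1/models/fashion_model:predict
print(predict_url("fashion_model", version=1))
# http://localhost:8501/v1/models/fashion_model/versions/1:predict
```

With the server running, predictions = post_predict(predict_url("fashion_model"), data) mirrors the requests call above. Unlike requests.post, urlopen raises urllib.error.HTTPError for non-2xx responses, so a bad request surfaces as an exception instead of a JSON error body.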

A particular version of the servable

Now let's specify a particular version of our servable. Since we only have one, let's select version 1. We'll also look at all three results.

# docs_infra: no_execute
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/fashion_model/versions/1:predict', data=data, headers=headers)
predictions = json.loads(json_response.text)['predictions']

for i in range(0,3):
  show(i, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(
    class_names[np.argmax(predictions[i])], np.argmax(predictions[i]), class_names[test_labels[i]], test_labels[i]))
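Because the final Dense layer has no activation and the model was compiled with from_logits=True, each entry in predictions is a vector of 10 logits rather than probabilities. np.argmax picks the same class either way, since softmax preserves ordering, but if you also want a confidence score you can apply softmax yourself. The sketch below does this in plain Python; decode_prediction is a hypothetical helper, and the logits in the usage example are made up for illustration, not real model output.

```python
import math
from typing import List, Sequence, Tuple

CLASS_NAMES = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_prediction(logits: Sequence[float],
                      class_names: Sequence[str] = CLASS_NAMES) -> Tuple[int, str, float]:
    """Return (class index, class name, softmax probability) for one example."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)
    return idx, class_names[idx], probs[idx]


# Made-up logits for illustration; index 9 has the largest value.
fake_logits = [-2.0, -3.1, -1.5, -2.2, -1.9, 1.0, -0.8, 2.1, -0.5, 4.3]
idx, name, prob = decode_prediction(fake_logits)
print(idx, name, round(prob, 3))
```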