This tutorial shows you how to solve the Iris classification problem in TensorFlow using Estimators. An Estimator is a legacy TensorFlow high-level representation of a complete model. For more details, see the Estimators guide.
First things first
To get started, import TensorFlow and the other libraries you will need.
import tensorflow as tf
import pandas as pd
import numpy as np
The data set
The sample program in this document builds and tests a model that classifies Iris flowers into three different species based on the size of their sepals and petals.
You will train a model using the Iris data set. The Iris data set contains four features and one label. The four features identify the following botanical characteristics of individual Iris flowers:
- sepal length
- sepal width
- petal length
- petal width
Based on this information, you can define a few helpful constants for parsing the data:
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']
Next, download and parse the Iris data set using Keras and Pandas. Note that you keep distinct datasets for training and testing.
train_path = tf.keras.utils.get_file(
"iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
"iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")
train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
You can inspect your data to see that you have four float feature columns and one integer label column.
train.head()
For each of the datasets, split out the labels, which the model will be trained to predict.
train_y = train.pop('Species')
test_y = test.pop('Species')
# The label column has now been removed from the features.
train.head()
Overview of programming with Estimators
Now that you have the data set up, you can define a model using a TensorFlow Estimator. An Estimator is any class derived from tf.estimator.Estimator. TensorFlow provides a collection of pre-made Estimators in tf.estimator (for example, LinearRegressor) that implement common ML algorithms. Beyond those, you may write your own custom Estimators.
Pre-made Estimators are recommended when you are just getting started.
To write a TensorFlow program based on pre-made Estimators, you must perform the following tasks:
- Create one or more input functions.
- Define the model's feature columns.
- Instantiate an Estimator, specifying the feature columns and various hyperparameters.
- Call one or more methods on the Estimator object, passing the appropriate input function as the source of the data.
Let's see how those tasks are implemented for Iris classification.
Create input functions
You must create input functions to supply data for training, evaluation, and prediction.
An input function is a function that returns a tf.data.Dataset object that outputs the following two-element tuple:
- features: A Python dictionary in which:
  - Each key is the name of a feature.
  - Each value is an array containing all of that feature's values.
- label: An array containing the values of the label for every example.
Just to demonstrate the format of the input function, here's a simple implementation:
def input_evaluation_set():
    features = {'SepalLength': np.array([6.4, 5.0]),
                'SepalWidth': np.array([2.8, 2.3]),
                'PetalLength': np.array([5.6, 3.3]),
                'PetalWidth': np.array([2.2, 1.0])}
    labels = np.array([2, 1])
    return features, labels
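The sketch below makes the tuple format concrete without TensorFlow. It is a rough, pure-Python imitation and not how tf.data is implemented. It first checks that every feature array has one value per label. It then slices the (features, labels) pair along the example axis, which is roughly what tf.data.Dataset.from_tensor_slices does with it. The helper slice_examples is hypothetical, and it uses plain lists instead of NumPy arrays.

```python
def slice_examples(features, labels):
    """Pure-Python imitation of slicing (features, labels) along the example axis.

    Hypothetical helper: it mimics the per-example elements that
    tf.data.Dataset.from_tensor_slices would yield, without using TensorFlow.
    """
    n = len(labels)
    # Every feature column must have exactly one value per example.
    for name, column in features.items():
        if len(column) != n:
            raise ValueError(f"{name!r} has {len(column)} values, expected {n}")
    # Build one (feature dict, label) pair per example.
    return [
        ({name: column[i] for name, column in features.items()}, labels[i])
        for i in range(n)
    ]


features = {'SepalLength': [6.4, 5.0],
            'SepalWidth': [2.8, 2.3],
            'PetalLength': [5.6, 3.3],
            'PetalWidth': [2.2, 1.0]}
labels = [2, 1]

for example, label in slice_examples(features, labels):
    print(example, '->', label)
```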
Your input function may generate the features dictionary and label list any way you like. However, it is recommended that you use TensorFlow's Dataset API, which can parse all sorts of data.
The Dataset API can handle a lot of common cases for you. For example, using the Dataset API, you can easily read in records from a large collection of files in parallel and join them into a single stream.
To keep things simple, this example loads the data with pandas and builds an input pipeline from this in-memory data:
def input_fn(features, labels, training=True, batch_size=256):
    """An input function for training or evaluating"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle and repeat if you are in training mode.
    if training:
        dataset = dataset.shuffle(1000).repeat()

    return dataset.batch(batch_size)
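The sketch below imitates, in plain Python, what that input function does in each mode. It shows why evaluation sees exactly one epoch while training produces an endless stream that the steps argument must cut off.

The helper toy_input_fn is hypothetical and is not TensorFlow code. It also simplifies one detail: it reshuffles the whole epoch, whereas shuffle(1000) draws from a 1000-element buffer. The two behave the same only when the data set has at most 1000 examples.

```python
import itertools
import random


def toy_input_fn(examples, training=True, batch_size=256, seed=0):
    """Pure-Python imitation of input_fn (hypothetical helper, no TensorFlow).

    training=True:  reshuffle every epoch and repeat forever, then batch.
    training=False: one pass in the original order, then batch.
    """
    rng = random.Random(seed)

    def stream():
        if not training:
            yield from examples          # exactly one epoch
            return
        while examples:                  # like repeat() with no count
            epoch = list(examples)
            rng.shuffle(epoch)           # a fresh order each epoch
            yield from epoch

    it = stream()
    while True:
        # Batches are cut from the continuous stream, so a training batch
        # may mix examples from two consecutive epochs.
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch                      # the last eval batch may be short


data = list(range(10))  # ten hypothetical examples

print([len(b) for b in toy_input_fn(data, training=False, batch_size=4)])  # [4, 4, 2]

# The training stream never ends, so take only a fixed number of batches,
# much as the steps argument to train() does.
first_five = itertools.islice(toy_input_fn(data, training=True, batch_size=4), 5)
print([len(b) for b in first_five])  # [4, 4, 4, 4, 4]
```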
Define the feature columns
A feature column is an object describing how the model should use raw input data from the features dictionary. When you build an Estimator model, you pass it a list of feature columns that describes each of the features you want the model to use. The tf.feature_column module provides many options for representing data to the model.
For Iris, the four raw features are numeric values. You'll build a list of feature columns that tells the Estimator model to represent each feature as a 32-bit floating-point value. The code to create the feature columns is:
# Feature columns describe how to use the input.
my_feature_columns = []
for key in train.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))
Feature columns can be far more sophisticated than those shown here. You can read more about Feature Columns in this guide.
Now that you have the description of how you want the model to represent the raw features, you can build the estimator.
Instantiate an estimator
The Iris problem is a classic classification problem. Fortunately, TensorFlow provides several pre-made classifier Estimators, including:
- tf.estimator.DNNClassifier for deep models that perform multi-class classification.
- tf.estimator.DNNLinearCombinedClassifier for wide & deep models.
- tf.estimator.LinearClassifier for classifiers based on linear models.
For the Iris problem, tf.estimator.DNNClassifier seems like the best choice.
Here's how you instantiate this Estimator:
# Build a DNN with two hidden layers.
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    # Two hidden layers of 30 and 10 nodes respectively.
    hidden_units=[30, 10],
    # The model must choose between 3 classes.
    n_classes=3)
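The arithmetic below counts the trainable weights of this network, to give a rough sense of its size. It assumes two things:
- Each hidden layer and the final logits layer is a fully connected layer with a bias.
- Each of the four numeric_column features contributes one input value.

The helper dense_param_count is hypothetical. Optimizer state (for example, Adagrad accumulators) is not counted.

```python
def dense_param_count(layer_sizes):
    """Count weights and biases in a stack of fully connected layers.

    layer_sizes lists every layer width, input first and output (logits) last.
    """
    return sum(n_in * n_out + n_out  # weight matrix plus bias vector
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


# 4 inputs -> hidden_units=[30, 10] -> n_classes=3 logits
# 4*30+30 = 150, 30*10+10 = 310, 10*3+3 = 33
print(dense_param_count([4, 30, 10, 3]))  # 493
```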
Train, Evaluate, and Predict
Now that you have an Estimator object, you can call methods to do the following:
- Train the model.
- Evaluate the trained model.
- Use the trained model to make predictions.
Train the model
Train the model by calling the Estimator's train method as follows:
# Train the Model.
classifier.train(
    input_fn=lambda: input_fn(train, train_y, training=True),
    steps=5000)
INFO:tensorflow:loss = 1.6850352, step = 0
INFO:tensorflow:loss = 0.71678615, step = 1000
INFO:tensorflow:loss = 0.59899795, step = 2000
INFO:tensorflow:loss = 0.5345396, step = 3000
INFO:tensorflow:loss = 0.49448222, step = 4000
INFO:tensorflow:Loss for final step: 0.4764105.
Note that you wrap the input_fn call in a lambda. This captures the arguments while still giving the Estimator what it expects: an input function that takes no arguments. The steps argument tells the method to stop training after the given number of training steps.
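The sketch below shows two common ways to turn an input function with arguments into the zero-argument callable the Estimator expects. It runs without TensorFlow because it uses two hypothetical stand-ins:
- fake_train replaces Estimator.train and simply calls the function it receives.
- describe_input replaces the real input_fn.

functools.partial is offered here as an alternative to the lambda; the tutorial itself uses the lambda.

```python
import functools


def describe_input(features, labels, training=True, batch_size=256):
    """Stand-in for input_fn: reports how it was called instead of building a Dataset."""
    return {'n_examples': len(labels), 'training': training, 'batch_size': batch_size}


def fake_train(input_fn):
    """Hypothetical stand-in for Estimator.train: calls input_fn with no arguments."""
    return input_fn()


features = {'SepalLength': [6.4, 5.0]}
labels = [2, 1]

# Option 1: a lambda closes over features and labels.
via_lambda = fake_train(lambda: describe_input(features, labels, training=True))

# Option 2: functools.partial binds the arguments up front.
via_partial = fake_train(functools.partial(describe_input, features, labels, training=True))

print(via_lambda == via_partial)  # True
```

The two options differ in one way. A lambda looks up features and labels when it is called, so rebinding those names before training starts changes what it sees. partial holds on to the objects it was given.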
Evaluate the trained model
Now that the model has been trained, you can get some statistics on its performance. The following code block evaluates the accuracy of the trained model on the test data:
eval_result = classifier.evaluate(
    input_fn=lambda: input_fn(test, test_y, training=False))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.53333336, average_loss = 0.559994, global_step = 5000, loss = 0.559994

Test set accuracy: 0.533
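Unlike the call to the train method, you did not pass the steps argument to evaluate. Because input_fn is called with training=False, the dataset is not repeated. The input function therefore yields only a single epoch of data, and evaluation stops when that data is exhausted.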
The eval_result dictionary also contains:
- average_loss: the mean loss per sample.
- loss: the mean loss per mini-batch.
- global_step: the number of training iterations the estimator underwent.
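The difference between the two loss values is easiest to see with a small calculation. The sketch below is a simplified, pure-Python model of how the two metrics aggregate, using hypothetical per-example losses. It illustrates an assumption about the aggregation and is not the Estimator's code.

When every example lands in one batch, the two numbers coincide. That is consistent with the log above reporting the same value, 0.559994, for both average_loss and loss.

```python
def eval_losses(per_example_losses, batch_size):
    """Return (average_loss, loss) under a simplified model of the two metrics.

    average_loss: mean over all examples.
    loss:         mean of the per-batch mean losses.
    """
    batches = [per_example_losses[i:i + batch_size]
               for i in range(0, len(per_example_losses), batch_size)]
    average_loss = sum(per_example_losses) / len(per_example_losses)
    loss = sum(sum(b) / len(b) for b in batches) / len(batches)
    return average_loss, loss


losses = [0.2, 0.4, 0.6, 1.0, 2.0]  # hypothetical per-example losses

# Batches of 4 give [0.2, 0.4, 0.6, 1.0] and [2.0]. The short batch is
# over-weighted in `loss`, so the two metrics differ (about 0.84 vs 1.275).
print(eval_losses(losses, batch_size=4))

# A single batch holding every example: the two metrics agree.
print(eval_losses(losses, batch_size=256))
```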
Making predictions (inferring) from the trained model
You now have a trained model, and you can use it to predict the species of an Iris flower based on some unlabeled measurements. As with training and evaluation, you make predictions using a single function call:
# Generate predictions from the model
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
'SepalLength': [5.1, 5.9, 6.9],
'SepalWidth': [3.3, 3.0, 3.1],
'PetalLength': [1.7, 4.2, 5.4],
'PetalWidth': [0.5, 1.5, 2.1],
}
def predict_input_fn(features, batch_size=256):
    """An input function for prediction."""
    # Convert the inputs to a Dataset without labels.
    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)

# A distinct name avoids shadowing the training/evaluation input_fn.
predictions = classifier.predict(
    input_fn=lambda: predict_input_fn(predict_x))
The predict method returns a Python iterable that yields a dictionary of prediction results for each example. The following code prints a few predictions and their probabilities:
for pred_dict, expec in zip(predictions, expected):
    class_id = pred_dict['class_ids'][0]
    probability = pred_dict['probabilities'][class_id]

    print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
        SPECIES[class_id], 100 * probability, expec))
Prediction is "Setosa" (78.7%), expected "Setosa"
Prediction is "Virginica" (49.3%), expected "Versicolor"
Prediction is "Virginica" (63.2%), expected "Virginica"
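For a multi-class DNNClassifier, the probabilities entry of each prediction dictionary is assumed here to be the softmax of the model's logits, with class_ids holding the index of the largest one. The pure-Python sketch below reproduces that relationship for one example. The logits are hypothetical and were not produced by the model trained above. softmax and describe are illustrative helpers and are not part of tf.estimator.

```python
import math

SPECIES = ['Setosa', 'Versicolor', 'Virginica']


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def describe(logits):
    """Turn one example's logits into (species, probability), like the loop above."""
    probabilities = softmax(logits)
    class_id = max(range(len(probabilities)), key=probabilities.__getitem__)
    return SPECIES[class_id], probabilities[class_id]


# Hypothetical logits; these are NOT values produced by the trained model.
species, p = describe([2.0, 0.5, -1.0])
print('Prediction is "{}" ({:.1f}%)'.format(species, 100 * p))
```

Subtracting the largest logit before calling math.exp leaves the probabilities unchanged but avoids overflow when logits are large.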