
Get started with TensorFlow 2.0 for experts


This is a Google Colaboratory notebook file. Python programs are run directly in the browser—a great way to learn and use TensorFlow. To run the Colab notebook:

  1. Connect to a Python runtime: At the top-right of the menu bar, select CONNECT.
  2. Run all the notebook code cells: Select Runtime > Run all.

For more examples and guides, see the TensorFlow tutorials.

To get started, install the TensorFlow 2.0 alpha build and import it, along with TensorFlow Datasets, into your program:

from __future__ import absolute_import, division, print_function

!pip install -q tensorflow==2.0.0-alpha0
import tensorflow_datasets as tfds
import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
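
If the install succeeded, the version string should start with 2.0.0-alpha0. A quick sanity check, not part of the training code:

print(tf.__version__)  # expected to start with: 2.0.0-alpha0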

Load and prepare the MNIST dataset. Convert the samples from integers to floating-point numbers:

dataset, info = tfds.load('mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = dataset['train'], dataset['test']
Downloading / extracting dataset mnist (11.06 MiB) to /root/tensorflow_datasets/mnist/1.0.0...
(download, extraction, and shuffling progress output truncated)

def convert_types(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255  # scale pixel values from [0, 255] to [0, 1]
  return image, label

mnist_train = mnist_train.map(convert_types).shuffle(10000).batch(32)
mnist_test = mnist_test.map(convert_types).batch(32)
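
To sanity-check the input pipeline, you can take a single batch and inspect its shape; each MNIST image is 28x28 pixels with one channel. A quick check, not part of the original flow:

# Expect images of shape (32, 28, 28, 1) and labels of shape (32,).
for image_batch, label_batch in mnist_train.take(1):
  print(image_batch.shape, label_batch.shape)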

Build the tf.keras model using the Keras model subclassing API:

class MyModel(Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.conv1 = Conv2D(32, 3, activation='relu')
    self.flatten = Flatten()
    self.d1 = Dense(128, activation='relu')
    self.d2 = Dense(10, activation='softmax')

  def call(self, x):
    x = self.conv1(x)
    x = self.flatten(x)
    x = self.d1(x)
    return self.d2(x)
  
model = MyModel()
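
Subclassed models create their weights lazily on the first call, so a dummy batch is enough to build the model and confirm the output shape. Illustrative only; the all-zeros input is an assumption, not real data:

# One all-zeros image through the untrained model builds its weights;
# the output is a (1, 10) tensor of class probabilities.
print(model(tf.zeros([1, 28, 28, 1])).shape)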

Choose an optimizer and loss function for training:

loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

optimizer = tf.keras.optimizers.Adam()
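
As a worked example of what this loss computes (values chosen purely for illustration): given an integer label and a row of predicted class probabilities, SparseCategoricalCrossentropy returns the negative log-probability assigned to the true class, so a confident correct prediction yields a small loss.

# Label 1 with predicted probabilities [0.05, 0.9, 0.05]
# gives -log(0.9) ~= 0.105.
print(loss_object([1], [[0.05, 0.9, 0.05]]).numpy())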

Select metrics to measure the loss and the accuracy of the model. These metrics accumulate values across epochs; the training loop below prints the overall result.

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
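
To make the accumulation concrete, here is a minimal sketch with a standalone Mean metric: each call folds a new value into the running state, and result() reports the aggregate so far.

m = tf.keras.metrics.Mean()
m(2.0)  # running mean is now 2.0
m(4.0)  # running mean is now 3.0
print(m.result().numpy())  # 3.0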

Train the model using tf.GradientTape:

@tf.function
def train_step(image, label):
  with tf.GradientTape() as tape:
    predictions = model(image)
    loss = loss_object(label, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  
  train_loss(loss)
  train_accuracy(label, predictions)
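
The @tf.function decorator traces the Python function into a TensorFlow graph on its first call, so later calls run the compiled graph instead of re-executing Python line by line. A minimal illustration, separate from the model above:

@tf.function
def square(x):
  return x * x

print(square(tf.constant(3.0)).numpy())  # 9.0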

Now test the model:

@tf.function
def test_step(image, label):
  predictions = model(image)
  t_loss = loss_object(label, predictions)
  
  test_loss(t_loss)
  test_accuracy(label, predictions)
EPOCHS = 5

for epoch in range(EPOCHS):
  for image, label in mnist_train:
    train_step(image, label)
  
  for test_image, test_label in mnist_test:
    test_step(test_image, test_label)
  
  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
  print(template.format(epoch+1,
                        train_loss.result(),
                        train_accuracy.result()*100,
                        test_loss.result(),
                        test_accuracy.result()*100))
Epoch 1, Loss: 0.13152919709682465, Accuracy: 96.13166809082031, Test Loss: 0.058612946420907974, Test Accuracy: 98.16999816894531
Epoch 2, Loss: 0.08616659790277481, Accuracy: 97.43250274658203, Test Loss: 0.06092171370983124, Test Accuracy: 98.04999542236328
Epoch 3, Loss: 0.06458542495965958, Accuracy: 98.05055236816406, Test Loss: 0.06026700511574745, Test Accuracy: 98.13667297363281
Epoch 4, Loss: 0.05185882747173309, Accuracy: 98.42291259765625, Test Loss: 0.05965104699134827, Test Accuracy: 98.16999816894531
Epoch 5, Loss: 0.04336000606417656, Accuracy: 98.67133331298828, Test Loss: 0.05901845172047615, Test Accuracy: 98.22599792480469

The image classifier is now trained to ~98% accuracy on this dataset. To learn more, read the TensorFlow tutorials.