Using TFRecords and tf.Example


To read data efficiently it can be helpful to serialize your data and store it in a set of files (100-200MB each) that can each be read linearly. This is especially true if the data is being streamed over a network. This can also be useful for caching any data-preprocessing.
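As a rough illustration of the file-size guideline above, the sketch below estimates how many shards a dataset would need and generates file names for them. The helper name plan_shards, the 150 MB default target, and the name-NNNNN-of-NNNNN naming pattern are assumptions made for this example, not something the TFRecord format requires.

```python
def plan_shards(total_bytes, target_shard_bytes=150 * 1024 ** 2, prefix='data'):
    """Return a list of shard file names for a dataset of total_bytes bytes.

    Hypothetical helper: the default target of 150 MB sits in the middle of
    the 100-200MB range suggested above.
    """
    # Ceiling division using integers only, with a minimum of one shard.
    num_shards = max(1, -(-total_bytes // target_shard_bytes))
    return ['{}-{:05d}-of-{:05d}.tfrecord'.format(prefix, index, num_shards)
            for index in range(num_shards)]


# Example: a dataset that serializes to roughly 10 GiB.
shard_names = plan_shards(total_bytes=10 * 1024 ** 3)
print(len(shard_names), shard_names[0], shard_names[-1])
```

Each observation would then be written to one of these files, for instance by cycling through the shards in round-robin order.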

The TFRecord format is a simple format for storing a sequence of binary records.
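To make "a sequence of binary records" concrete: TensorFlow's documentation describes each record as a little-endian 64-bit length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. The pure-Python sketch below follows that layout for illustration only. write_record and read_records are hypothetical helper names, and real code should use the TensorFlow writers and readers shown later in this notebook.

```python
import io
import struct

_MASK_DELTA = 0xA282EAD8


def _make_crc32c_table():
    """Build the lookup table for CRC-32C (reflected polynomial 0x82F63B78)."""
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_CRC_TABLE = _make_crc32c_table()


def crc32c(data):
    """Compute the CRC-32C (Castagnoli) checksum of a bytes object."""
    crc = 0xFFFFFFFF
    for byte in bytearray(data):
        crc = _CRC_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data):
    """Rotate the checksum right by 15 bits and add a constant, modulo 2**32."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + _MASK_DELTA) & 0xFFFFFFFF


def write_record(stream, payload):
    """Append one framed record to a binary stream (hypothetical helper)."""
    header = struct.pack('<Q', len(payload))
    stream.write(header)
    stream.write(struct.pack('<I', masked_crc(header)))
    stream.write(payload)
    stream.write(struct.pack('<I', masked_crc(payload)))


def read_records(stream):
    """Yield each payload from a binary stream, verifying both checksums."""
    while True:
        header = stream.read(8)
        if not header:
            return
        (length,) = struct.unpack('<Q', header)
        (length_crc,) = struct.unpack('<I', stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError('corrupted length header')
        payload = stream.read(length)
        (payload_crc,) = struct.unpack('<I', stream.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError('corrupted record payload')
        yield payload


# Round-trip two records through an in-memory buffer.
buffer = io.BytesIO()
for payload in [b'first record', b'second']:
    write_record(buffer, payload)
buffer.seek(0)
records = list(read_records(buffer))
print(records)
```

Because every record carries its own length, a reader can scan the file linearly without knowing anything about what the payloads contain.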

Protocol buffers are a cross-platform, cross-language library for efficient serialization of structured data.

Protocol messages are defined by .proto files; these are often the easiest way to understand a message type.

The tf.Example message (or protobuf) is a flexible message type that represents a {"string": value} mapping. It is designed for use with TensorFlow and is used throughout the higher-level APIs such as TFX.

This notebook will demonstrate how to create, parse, and use the tf.Example message, and then serialize, write, and read tf.Example messages to and from .tfrecord files.

Setup

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
tf.enable_eager_execution()

import numpy as np
import IPython.display as display

tf.Example

Data types for tf.Example

Fundamentally a tf.Example is a {"string": tf.train.Feature} mapping.

The tf.train.Feature message type can accept one of the following three types (see the .proto file for reference). Most other generic types can be coerced into one of these.

  1. tf.train.BytesList (the following types can be coerced)

    • string
    • byte
  2. tf.train.FloatList (the following types can be coerced)

    • float (float32)
    • double (float64)
  3. tf.train.Int64List (the following types can be coerced)

    • bool
    • enum
    • int32
    • uint32
    • int64
    • uint64

To convert a standard TensorFlow type to a tf.Example-compatible tf.train.Feature, you can use the shortcut functions below. Each function takes a scalar input value and returns a tf.train.Feature containing one of the three list types above.

# The following functions can be used to convert a value to a type compatible
# with tf.Example.

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

Below are some examples of how these functions work. Note the varying input types and the standardized output types. If the input type for a function does not match one of the coercible types stated above, the function will raise an exception (e.g. _int64_feature(1.0) will error out, since 1.0 is a float and should be used with the _float_feature function instead).

print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))

print(_float_feature(np.exp(1)))

print(_int64_feature(True))
print(_int64_feature(1))
bytes_list {
  value: "test_string"
}

bytes_list {
  value: "test_bytes"
}

float_list {
  value: 2.7182817459106445
}

int64_list {
  value: 1
}

int64_list {
  value: 1
}
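The coercion rules above can be summarized as a small dispatch on the Python type of a scalar. The sketch below is a hypothetical helper (feature_list_name is not part of TensorFlow) and it only considers built-in Python types; NumPy scalars and tensors would need additional cases.

```python
def feature_list_name(value):
    """Return which tf.train.Feature list a built-in Python scalar belongs in.

    Hypothetical helper for illustration; it does not build a Feature itself.
    """
    if isinstance(value, (bytes, bytearray)):
        return 'bytes_list'
    # bool is a subclass of int, so this branch also covers True and False.
    if isinstance(value, int):
        return 'int64_list'
    if isinstance(value, float):
        return 'float_list'
    raise TypeError(
        'no tf.train.Feature list for {}; encode text to bytes first'.format(
            type(value).__name__))


for sample in [b'test_string', True, 1, 2.718]:
    print(repr(sample), '->', feature_list_name(sample))
```

A dispatcher like this could sit in front of the three shortcut functions so that a float is never passed to _int64_feature by mistake.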

All proto messages can be serialized to a binary-string using the .SerializeToString method.

feature = _float_feature(np.exp(1))

feature.SerializeToString()
b'\x12\x06\n\x04T\xf8-@'
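The bytes above can be unpacked by hand, which also shows why the value printed earlier is 2.7182817459106445 rather than the full double-precision e: FloatList stores 32-bit floats. This is a sketch that assumes the field numbers from the feature.proto definition (float_list is field 2 of Feature, value is field 1 of FloatList) and only handles this one short message.

```python
import math
import struct

# The serialized feature shown above.
serialized = bytearray(b'\x12\x06\n\x04T\xf8-@')

# Outer field: the tag byte packs (field_number << 3) | wire_type.
# Wire type 2 means "length-delimited", so the next byte is a length.
outer_field, outer_wire_type = serialized[0] >> 3, serialized[0] & 0x7
outer_length = serialized[1]
float_list = serialized[2:2 + outer_length]

# Inner field: the packed repeated float values of the FloatList.
inner_field, inner_wire_type = float_list[0] >> 3, float_list[0] & 0x7
inner_length = float_list[1]
packed = bytes(float_list[2:2 + inner_length])

# Each value is a 4-byte little-endian IEEE 754 float.
values = struct.unpack('<{}f'.format(inner_length // 4), packed)
print(outer_field, inner_field, values)

# Rounding e to float32 and back reproduces the stored value exactly.
e_as_float32 = struct.unpack('<f', struct.pack('<f', math.e))[0]
print(values[0] == e_as_float32, values[0] == math.e)
```

The second comparison is False because the round trip through float32 discards the low-order bits of the double.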

Creating a tf.Example message

Suppose you want to create a tf.Example message from existing data. In practice, the dataset may come from anywhere, but the procedure of creating the tf.Example message from a single observation will be the same.

  1. Within each observation, each value needs to be converted to a tf.train.Feature containing one of the 3 compatible types, using one of the functions above.

  2. We create a map (dictionary) from the feature name string to the encoded feature value produced in #1.

  3. The map produced in #2 is converted to a Features message.

In this notebook, we will create a dataset using NumPy.

This dataset will have 4 features:

  • a boolean feature, False or True with equal probability
  • an integer feature, chosen uniformly at random from [0, 5)
  • a bytes feature, looked up from a table of strings using the integer feature as the index
  • a float feature from a standard normal distribution

Consider a sample consisting of 10,000 independently and identically distributed observations from each of the above distributions.

# the number of observations in the dataset
n_observations = int(1e4)

# boolean feature, encoded as False or True
feature0 = np.random.choice([False, True], n_observations)

# integer feature, random from 0 to 4
feature1 = np.random.randint(0, 5, n_observations)

# bytes feature
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# float feature, from a standard normal distribution
feature3 = np.random.randn(n_observations)

Each of these features can be coerced into a tf.Example-compatible type using one of _bytes_feature, _float_feature, _int64_feature. We can then create a tf.Example message from these encoded features.

def serialize_example(feature0, feature1, feature2, feature3):
  """
  Creates a tf.Example message ready to be written to a file.
  """
  
  # Create a dictionary mapping the feature name to the tf.Example-compatible
  # data type.
  
  feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
  }
  
  # Create a Features message using tf.train.Example.
  
  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

For example, suppose we have a single observation from the dataset, [False, 4, b'goat', 0.9876]. We can create and print the serialized tf.Example message for this observation using serialize_example(). Each single observation is written as a Features message, as described above. Note that the tf.Example message is just a wrapper around the Features message.

# This is an example observation from the dataset.
serialized_example = serialize_example(False, 4, b'goat', 0.9876)
serialized_example
b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?'

To decode the message use the tf.train.Example.FromString method.

example_proto = tf.train.Example.FromString(serialized_example)
example_proto
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.9876000285148621
      }
    }
  }
}
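FromString does all of the decoding for you. To see what it is working with, the sketch below walks the serialized bytes shown earlier using only the protocol buffer wire format and recovers the feature names. The helper names are hypothetical, only length-delimited fields are handled, and the field numbers (features is field 1 of Example; map entries use field 1 for the key and field 2 for the value) are taken from the .proto definitions.

```python
def read_varint(buf, pos):
    """Decode a base-128 varint starting at pos; return (value, next_pos)."""
    result, shift = 0, 0
    while True:
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
        shift += 7


def length_delimited_fields(data):
    """Yield (field_number, payload) for each length-delimited field in data."""
    buf = bytearray(data)
    pos = 0
    while pos < len(buf):
        tag, pos = read_varint(buf, pos)
        if tag & 0x7 != 2:
            raise ValueError('this sketch only handles length-delimited fields')
        length, pos = read_varint(buf, pos)
        yield tag >> 3, bytes(buf[pos:pos + length])
        pos += length


# The serialized tf.Example from above, split per map entry for readability.
serialized_example = (
    b'\nR'
    b'\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00'
    b'\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04'
    b'\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'
    b'\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?')

# The Example message holds a single field: the Features message.
top_level = list(length_delimited_fields(serialized_example))
features = top_level[0][1]

# Each field inside Features is one map entry holding a key and a value.
feature_names = []
for _, entry in length_delimited_fields(features):
    entry_fields = dict(length_delimited_fields(entry))
    feature_names.append(entry_fields[1].decode('utf-8'))
print(feature_names)
```

Map entries are not guaranteed to be serialized in any particular order, which is why one of the raw records printed later in this notebook lists feature2 first.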

TFRecord files using tf.data

The tf.data module also provides tools for reading and writing data in TensorFlow.

Writing a TFRecord file

The easiest way to get the data into a dataset is to use the from_tensor_slices method.

Applied to an array, it returns a dataset of scalars.

tf.data.Dataset.from_tensor_slices(feature1)

Applied to a tuple of arrays, it returns a dataset of tuples:

features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))
features_dataset
# Use `take(1)` to only pull one example from the dataset.
for f0,f1,f2,f3 in features_dataset.take(1):
  print(f0)
  print(f1)
  print(f2)
  print(f3)
tf.Tensor(True, shape=(), dtype=bool)
tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(b'dog', shape=(), dtype=string)
tf.Tensor(0.15777874861525687, shape=(), dtype=float64)

Use the tf.data.Dataset.map method to apply a function to each element of a Dataset.

The mapped function must operate in TensorFlow graph mode: it must operate on and return tf.Tensors. A non-tensor function, like serialize_example, can be wrapped with tf.py_func to make it compatible.

Using tf.py_func requires that you specify the shape and type information that is otherwise unavailable:

def tf_serialize_example(f0,f1,f2,f3):
  tf_string = tf.py_func(
    serialize_example, 
    (f0,f1,f2,f3),  # pass these args to the above function.
    tf.string)      # the return type is tf.string.
  return tf.reshape(tf_string, ()) # The result is a scalar

Apply this function to each element in the dataset:

serialized_features_dataset = features_dataset.map(tf_serialize_example)
serialized_features_dataset

And write them to a TFRecord file:

filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)

Reading a TFRecord file

We can also read the TFRecord file using the tf.data.TFRecordDataset class.

More information on consuming TFRecord files using tf.data can be found here.

Using TFRecordDatasets can be useful for standardizing input data and optimizing performance.

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset

At this point the dataset contains serialized tf.train.Example messages. When iterated over it returns these as scalar string tensors.

Use the .take method to only show the first 10 records.

for raw_record in raw_dataset.take(10):
  print(repr(raw_record))
<tf.Tensor: id=43, shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xc1\x90!>'>
<tf.Tensor: id=45, shape=(), dtype=string, numpy=b'\nS\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xb4p\x91?'>
<tf.Tensor: id=47, shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xffY\xb9?'>
<tf.Tensor: id=49, shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xe8*\xc8\xbe'>
<tf.Tensor: id=51, shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04*L\x95?\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01'>
<tf.Tensor: id=53, shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xfa\xd4\x0f\xbf'>
<tf.Tensor: id=55, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xda\xd56?'>
<tf.Tensor: id=57, shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x046\xb3\xbe\xbe'>
<tf.Tensor: id=59, shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04B\xf3m>'>
<tf.Tensor: id=61, shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xfc\xf1\xed>'>

These tensors can be parsed using the function below.

# Create a description of the features.  
feature_description = {
    'feature0': tf.FixedLenFeature([], tf.int64, default_value=0),
    'feature1': tf.FixedLenFeature([], tf.int64, default_value=0),
    'feature2': tf.FixedLenFeature([], tf.string, default_value=''),
    'feature3': tf.FixedLenFeature([], tf.float32, default_value=0.0),
}

def _parse_function(example_proto):
  # Parse the input tf.Example proto using the dictionary above.
  return tf.parse_single_example(example_proto, feature_description)

Or use tf.parse_example to parse a whole batch at once.

Apply this function to each item in the dataset using the tf.data.Dataset.map method:

parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset 

Use eager execution to display the observations in the dataset. There are 10,000 observations in this dataset, but we only display the first 10. The data is displayed as a dictionary of features. Each item is a tf.Tensor, and the numpy element of this tensor displays the value of the feature.

for parsed_record in parsed_dataset.take(10):
  print(repr(parsed_record))

Here, the tf.parse_single_example function unpacks the tf.Example fields into standard tensors.

TFRecord files using tf.python_io

The tf.python_io module also contains pure-Python functions for reading and writing TFRecord files.

Writing a TFRecord file

Now write the 10,000 observations to the file test.tfrecord. Each observation is converted to a tf.Example message, then written to the file. We can then verify that the file test.tfrecord has been created.

# Write the `tf.Example` observations to the file.
with tf.python_io.TFRecordWriter(filename) as writer:
  for i in range(n_observations):
    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
    writer.write(example)
!ls
test.tfrecord

Reading a TFRecord file

Suppose we now want to read this data back, to be input as data into a model.

The following example imports the data as is, as a tf.Example message. This can be useful to verify that the file contains the data that we expect. It can also be useful if the input data is stored as TFRecords but you would prefer to work with NumPy data (or some other input data type), since this approach lets us read the values themselves.

We iterate through the records in the file, parse each one into a tf.Example message, and can then read or store the values within.

record_iterator = tf.python_io.tf_record_iterator(path=filename)

for string_record in record_iterator:
  example = tf.train.Example()
  example.ParseFromString(string_record)
  
  print(example)
  
  # Exit after 1 iteration as this is purely demonstrative.
  break
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "dog"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.1577787548303604
      }
    }
  }
}

The features of the example object (created above, of type tf.Example) can be accessed using its getters, as with any protocol buffer message. example.features returns a Features message, and its feature field is a map from feature name to feature value that behaves much like a Python dictionary.

print(dict(example.features.feature))
{'feature1': int64_list {
  value: 1
}
, 'feature3': float_list {
  value: 0.1577787548303604
}
, 'feature2': bytes_list {
  value: "dog"
}
, 'feature0': int64_list {
  value: 1
}
}

From this dictionary, you can look up any given feature by name, as with a regular dictionary.

print(example.features.feature['feature3'])
float_list {
  value: 0.1577787548303604
}

Now, we can access the value using the getters again.

print(example.features.feature['feature3'].float_list.value)
[0.1577787548303604]

Walkthrough: Reading/Writing Image Data

This is an example of how to read and write image data using TFRecords. The purpose is to show, end to end, how to take input data (in this case an image), write it to a TFRecord file, then read the file back and display the image.

This can be useful if, for example, you want to use several models on the same input dataset. Instead of storing the image data raw, it can be preprocessed into the TFRecords format, and that can be used in all further processing and modelling.

First, let's download this image of a cat in the snow and this photo of the Williamsburg Bridge, NYC under construction.

Fetch the images

cat_in_snow  = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://upload.wikimedia.org/wikipedia/commons/thumb/f/fe/New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
Downloading data from https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/320px-Felis_catus-cat_on_snow.jpg
24576/17858 [=========================================] - 0s 0us/step
Downloading data from https://upload.wikimedia.org/wikipedia/commons/thumb/f/fe/New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg
16384/15477 [===============================] - 0s 0us/step
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a href="https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))

jpeg

Image cc-by: Von.grzanka

display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a href="https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">source</a>'))

jpeg

source

Write the TFRecord file

As we did earlier, we can now encode the features as types compatible with tf.Example. In this case, we will not only store the raw image string as a feature, but we will store the height, width, depth, and an arbitrary label feature, which is used when we write the file to distinguish between the cat image and the bridge image. We will use 0 for the cat image, and 1 for the bridge image.

image_labels = {
    cat_in_snow : 0,
    williamsburg_bridge : 1,
}
# This is an example, just using the cat image.
image_string = open(cat_in_snow, 'rb').read()

label = image_labels[cat_in_snow]

# Create a dictionary with features that may be relevant.
def image_example(image_string, label):
  image_shape = tf.image.decode_jpeg(image_string).shape

  feature = {
      'height': _int64_feature(image_shape[0]),
      'width': _int64_feature(image_shape[1]),
      'depth': _int64_feature(image_shape[2]),
      'label': _int64_feature(label),
      'image_raw': _bytes_feature(image_string),
  }

  return tf.train.Example(features=tf.train.Features(feature=feature))

for line in str(image_example(image_string, label)).split('\n')[:15]:
  print(line)
print('...')
features {
  feature {
    key: "depth"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "height"
    value {
      int64_list {
        value: 213
      }
...

We see that all of the features are now stored in the tf.Example message. Now, we functionalize the code above and write the example messages to a file, images.tfrecords.

# Write the raw image files to images.tfrecords.
# First, process the two images into tf.Example messages.
# Then, write to a .tfrecords file.

with tf.python_io.TFRecordWriter('images.tfrecords') as writer:
  for filename, label in image_labels.items():
    image_string = open(filename, 'rb').read()
    tf_example = image_example(image_string, label)
    writer.write(tf_example.SerializeToString())
!ls
images.tfrecords  test.tfrecord

Read the TFRecord file

We now have the file images.tfrecords and can iterate over its records to read back what we wrote. Since this use case only reproduces the image, the only feature we need is the raw image string. With the getters described above it would be example.features.feature['image_raw'].bytes_list.value[0]; the code below instead parses each record with tf.data and a feature description. The label feature can be used to determine which record is the cat as opposed to the bridge.

raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

# Create a dictionary describing the features.  
image_feature_description = {
    'height': tf.FixedLenFeature([], tf.int64),
    'width': tf.FixedLenFeature([], tf.int64),
    'depth': tf.FixedLenFeature([], tf.int64),
    'label': tf.FixedLenFeature([], tf.int64),
    'image_raw': tf.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  # Parse the input tf.Example proto using the dictionary above.
  return tf.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset

Recover the images from the TFRecord file:

for image_features in parsed_image_dataset:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))

jpeg

jpeg