Missed TensorFlow World? Check out the recap. Learn more

TFRecords と tf.Example の使用法

View on TensorFlow.org Run in Google Colab View source on GitHub

データの読み込みを効率的にするには、データをシリアライズし、連続的に読み込めるファイルのセット(各ファイルは 100-200MB)に保存することが有効です。データをネットワーク経由で流そうとする場合には、特にそうです。また、データの前処理をキャッシングする際にも役立ちます。

TFRecord 形式は、バイナリレコードのシリーズを保存するための単純な形式です。

プロトコルバッファ は、構造化データを効率的にシリアライズする、プラットフォームや言語に依存しないライブラリです。

プロトコルメッセージは .proto という拡張子のファイルで表されます。メッセージの型を識別するもっとも簡単な方法です。

tf.Example メッセージ(あるいはプロトコルバッファ)は、{"string": value} 形式のマッピングを表現する柔軟なメッセージ型です。これは、TensorFlow 用に設計され、TFX のような上位レベルの API で共通に使用されています。

このノートブックでは、tf.Example メッセージの作成、解析と使用法をデモし、その後、tf.Example メッセージをシリアライズして .tfrecord に書き出し、その後読み取る方法を示します。

注:こうした構造は有用ですが必ずそうしなければならなというものではありません。tf.data を使っていて、それでもデータの読み込みが訓練のボトルネックであるという場合でなければ、既存のコードを TFRecords を使用するために変更する必要はありません。データセットの性能改善のヒントは、 Data Input Pipeline Performance を参照ください。

設定

from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import numpy as np
import IPython.display as display

tf.Example

tf.Example 用のデータ型

基本的には tf.Example{"string": tf.train.Feature} というマッピングです。

tf.train.Feature メッセージ型は次の3つの型のうち1つをとることができます。(.proto file を参照)このほかの一般的なデータ型のほとんどは、強制的にこれらのうちの1つにすること可能です。

  1. tf.train.BytesList (次の型のデータを扱うことが可能)

    • string
    • byte
  2. tf.train.FloatList (次の型のデータを扱うことが可能)

    • float (float32)
    • double (float64)
  3. tf.train.Int64List (次の型のデータを扱うことが可能)

    • bool
    • enum
    • int32
    • uint32
    • int64
    • uint64

通常の TensorFlow の型を tf.Example 互換の tf.train.Feature に変換するには、次のショートカット関数を使うことができます。

どの関数も、1個のスカラー値を入力とし、上記の3つの list 型のうちの一つを含む tf.train.Feature を返します。

# 下記の関数を使うと値を tf.Example と互換性の有る型に変換できる

def _bytes_feature(value):
  """string / byte 型から byte_list を返す"""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """float / double 型から float_list を返す"""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """bool / enum / int / uint 型から Int64_list を返す"""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

注:単純化のため、このサンプルではスカラー値の入力のみを扱っています。スカラー値ではない特徴を扱うもっとも簡単な方法は、tf.serialize_tensor を使ってテンソルをバイナリ文字列に変換する方法です。TensorFlow では文字列はスカラー値として扱います。バイナリ文字列をテンソルに戻すには、tf.parse_tensor を使用します。

上記の関数の使用例を下記に示します。入力がさまざまな型であるのに対して、出力が標準化されていることに注目してください。入力が、強制変換できない型であった場合、例外が発生します。(例: _int64_feature(1.0) はエラーとなります。1.0 が浮動小数点数であるためで、代わりに _float_feature 関数を使用すべきです)

print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))

print(_float_feature(np.exp(1)))

print(_int64_feature(True))
print(_int64_feature(1))
bytes_list {
  value: "test_string"
}

bytes_list {
  value: "test_bytes"
}

float_list {
  value: 2.7182817459106445
}

int64_list {
  value: 1
}

int64_list {
  value: 1
}

主要なメッセージはすべて .SerializeToString を使ってバイナリ文字列にシリアライズすることができます。

feature = _float_feature(np.exp(1))

feature.SerializeToString()
b'\x12\x06\n\x04T\xf8-@'

tf.Example メッセージの作成

既存のデータから tf.Example を作成したいとします。実際には、データセットの出処はどこでもよいのですが、1件の観測記録からtf.Example メッセージを作る手順はおなじです。

  1. 観測記録それぞれにおいて、各値は上記の関数を使って3種類の互換性のある型からなる tf.train.Feature に変換する必要があります。

  2. 次に、特徴の名前を表す文字列と、#1 で作ったエンコード済みの特徴量を対応させたマップ(ディクショナリ)を作成します。

  3. 2 で作成したマップを特徴量メッセージに変換します。

このノートブックでは、NumPy を使ってデータセットを作成します。

このデータセットには4つの特徴量があります。 - False または True を表す論理値。出現確率は等しいものとします。 - ランダムなバイト値。全体において一様であるとします。 - [-10000, 10000) の範囲から一様にサンプリングした整数値。 - 標準正規分布からサンプリングした浮動小数点数。

サンプルは上記の分布から独立しておなじ様に分布した10,000件の観測記録からなるものとします。

# データセットに含まれる観測結果の件数
n_observations = int(1e4)

# ブール特徴量 False または True としてエンコードされている
feature0 = np.random.choice([False, True], n_observations)

# 整数特徴量  -10000 から 10000 の間の乱数
feature1 = np.random.randint(0, 5, n_observations)

# バイト特徴量
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# 浮動小数点数特徴量 標準正規分布から発生
feature3 = np.random.randn(n_observations)

これらの特徴量は、_bytes_feature, _float_feature, _int64_feature のいずれかを使って tf.Example 互換の型に強制変換されます。その後、エンコード済みの特徴量から tf.Example メッセージを作成できます。

def serialize_example(feature0, feature1, feature2, feature3):
  """
  ファイル出力可能な tf.Example メッセージを作成する
  """
  
  # 特徴量名と tf.Example 互換データ型との対応ディクショナリを作成
  
  feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
  }
  
  # tf.train.Example を用いて特徴メッセージを作成
  
  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

たとえば、データセットに [False, 4, bytes('goat'), 0.9876] という1つの観測記録があるとします。create_message() を使うとこの観測記録から tf.Example メッセージを作成し印字できます。上記のように、観測記録一つ一つが Features メッセージとして書かれています。tf.Example メッセージは、この Features メッセージを包むラッパーに過ぎないことに注意してください。

# データセットからの観測記録の例

example_observation = []

serialized_example = serialize_example(False, 4, b'goat', 0.9876)
serialized_example
b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?'

メッセージをデコードするには、tf.train.Example.FromString メソッドを使用します。

example_proto = tf.train.Example.FromString(serialized_example)
example_proto
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.9876000285148621
      }
    }
  }
}

TFRecord フォーマットの詳細

TFRecord ファイルにはレコードのシーケンスが含まれます。このファイルはシーケンシャル読み取りのみが可能です。

それぞれのレコードには、データを格納するためのバイト文字列とデータ長、そして整合性チェックのための CRC32C(Castagnoli 多項式を使った 32 ビットの CRC )ハッシュ値が含まれます。

各レコードのフォーマットは

uint64 長さ
uint32 長さのマスク済み crc32 ハッシュ値
byte   data[長さ]
uint32 データのマスク済み crc32 ハッシュ値

複数のレコードが結合されてファイルを構成します。CRC についてはここに說明があります。CRC のマスクは下記のとおりです。

masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul

注:TFRecord ファイルを作るのに、tf.Example を使わなければならないということはありません。tf.Example は、ディクショナリをバイト文字列にシリアライズする方法の1つです。エンコードされた画像データや、(tf.io.serialize_tensor を使ってシリアライズされ、tf.io.parse_tensor で読み込まれる)シリアライズされたテンソルもあります。そのほかのオプションについては、tf.io モジュールを参照してください。

tf.data を使用した TFRecord ファイル

tf.data モジュールには、TensorFlow でデータを読み書きするツールが含まれます。

TFRecord ファイルの書き出し

データをデータセットにするもっとも簡単な方法は from_tensor_slices メソッドです。

配列に適用すると、このメソッドはスカラー値のデータセットを返します。

tf.data.Dataset.from_tensor_slices(feature1)
<TensorSliceDataset shapes: (), types: tf.int64>

配列のタプルに適用すると、タプルのデータセットが返されます。

features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))
features_dataset
<TensorSliceDataset shapes: ((), (), (), ()), types: (tf.bool, tf.int64, tf.string, tf.float64)>
#  データセットから1つのサンプルだけを取り出すには `take(1)` を使います。
for f0,f1,f2,f3 in features_dataset.take(1):
  print(f0)
  print(f1)
  print(f2)
  print(f3)
tf.Tensor(False, shape=(), dtype=bool)
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(b'cat', shape=(), dtype=string)
tf.Tensor(-0.6229937096344282, shape=(), dtype=float64)

Dataset のそれぞれの要素に関数を適用するには、tf.data.Dataset.map メソッドを使用します。

マップされる関数は TensorFlow のグラフモードで動作する必要があります。関数は tf.Tensors を処理し、返す必要があります。create_example のような非テンソル関数は、互換性のため tf.py_func でラップすることができます。

tf.py_func を使用する際には、シェイプと型は取得できないため、指定する必要があります。

def tf_serialize_example(f0,f1,f2,f3):
  tf_string = tf.py_function(
    serialize_example, 
    (f0,f1,f2,f3),  # 上記の関数にこれらの引数を渡す
    tf.string)      # 戻り値の型は tf.string
  return tf.reshape(tf_string, ()) # 結果はスカラー
tf_serialize_example(f0,f1,f2,f3)
<tf.Tensor: id=30, shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x84|\x1f\xbf'>

この関数をデータセットのそれぞれの要素に適用します。

serialized_features_dataset = features_dataset.map(tf_serialize_example)
serialized_features_dataset
<MapDataset shapes: (), types: tf.string>
def generator():
  for features in features_dataset:
    yield serialize_example(*features)
serialized_features_dataset = tf.data.Dataset.from_generator(
    generator, output_types=tf.string, output_shapes=())
serialized_features_dataset
<DatasetV1Adapter shapes: (), types: tf.string>

TFRecord ファイルに書き出します。

filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)

TFRecord ファイルの読み込み

tf.data.TFRecordDataset クラスを使って TFRecord ファイルを読み込むこともできます。

tf.data を使って TFRecord ファイルを取り扱う際の詳細については、こちらを参照ください。

TFRecordDataset を使うことは、入力データを標準化し、パフォーマンスを最適化するのに有用です。

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>

この時点で、データセットにはシリアライズされた tf.train.Example メッセージが含まれています。データセットをイテレートすると、スカラーの文字列テンソルが返ってきます。

.take メソッドを使って最初の 10 レコードだけを表示します。

注:tf.data.Dataset をイテレートできるのは、Eager Execution が有効になっている場合のみです。

for raw_record in raw_dataset.take(10):
    print(repr(raw_record))
<tf.Tensor: id=50092, shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x84|\x1f\xbf'>
<tf.Tensor: id=50093, shape=(), dtype=string, numpy=b'\nS\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04o\xe8\x89\xbc'>
<tf.Tensor: id=50094, shape=(), dtype=string, numpy=b'\nS\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xe2\xfc\xd4\xbf'>
<tf.Tensor: id=50095, shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xffn\x80?'>
<tf.Tensor: id=50096, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x9d\x13\xd6>'>
<tf.Tensor: id=50097, shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04b*=?'>
<tf.Tensor: id=50098, shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xf8\x02\xce\xbe\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00'>
<tf.Tensor: id=50099, shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x845\xe8>'>
<tf.Tensor: id=50100, shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04I\xbbO\xbe'>
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04R\xca\x91?'>

これらのテンソルは下記の関数でパースできます。

注:ここでは、feature_description が必要です。データセットはグラフ実行を使用するため、この記述を使ってシェイプと型を構築するのです。

# 特徴の記述
feature_description = {
    'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}

def _parse_function(example_proto):
  # 上記の記述を使って入力の tf.Example を処理
  return tf.io.parse_single_example(example_proto, feature_description)

あるいは、tf.parse example を使ってバッチ全体を一度にパースします。

tf.data.Dataset.map メソッドを使って、データセットの各アイテムにこの関数を適用します。

parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset 
<MapDataset shapes: {feature1: (), feature2: (), feature0: (), feature3: ()}, types: {feature1: tf.int64, feature2: tf.string, feature0: tf.int64, feature3: tf.float32}>

Eager Execution を使ってデータセット中の観測記録を表示します。このデータセットには 10,000 件の観測記録がありますが、最初の 10 個だけ表示します。
データは特徴量のディクショナリの形で表示されます。それぞれの項目は tf.Tensor であり、このテンソルの numpy 要素は特徴量を表します。

for parsed_record in parsed_dataset.take(10):
  print(repr(raw_record))
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04R\xca\x91?'>
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04R\xca\x91?'>
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04R\xca\x91?'>
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04R\xca\x91?'>
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04R\xca\x91?'>
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04R\xca\x91?'>
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04R\xca\x91?'>
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04R\xca\x91?'>
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04R\xca\x91?'>
<tf.Tensor: id=50101, shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04R\xca\x91?'>

ここでは、tf.parse_exampletf.Example のフィールドを通常のテンソルに展開しています。

tf.python_io を使った TFRecord ファイル

tf.python_io モジュールには、TFRecord ファイルの読み書きのための純粋な Python 関数も含まれています。

TFRecord ファイルの書き出し

次にこの 10,000 件の観測記録を test.tfrecords ファイルに出力します。観測記録はそれぞれ tf.Example メッセージに変換され、ファイルに出力されます。その後、test.tfrecords ファイルが作成されたことを確認することができます。

# `tf.Example` 観測記録をファイルに出力
with tf.io.TFRecordWriter(filename) as writer:
  for i in range(n_observations):
    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
    writer.write(example)
!du -sh {filename}
984K    test.tfrecord

TFRecord ファイルの読み込み

これらのシリアライズされたテンソルは、tf.train.Example.ParseFromString を使って簡単にパースできます。

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>
for raw_record in raw_dataset.take(1):
  example = tf.train.Example()
  example.ParseFromString(raw_record.numpy())
  print(example)
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "cat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: -0.6229937076568604
      }
    }
  }
}

ウォークスルー: 画像データの読み書き

以下は、TFRecord を使って画像データを読み書きする方法の例です。この例の目的は、データ(この場合は画像)を入力し、そのデータを TFRecord ファイルに書き込んで、再びそのファイルを読み込み、画像を表示するという手順を最初から最後まで示すことです。

これは、たとえば、おなじ入力データセットを使って複数のモデルを構築するといった場合に役立ちます。画像データをそのまま保存する代わりに、TFRecord 形式に前処理しておき、その後の処理やモデル構築に使用することができます。

まずは、雪の中の猫の画像と、ニューヨーク市にあるウイリアムズバーグ橋の 写真をダウンロードしましょう。

画像の取得

cat_in_snow  = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg
24576/17858 [=========================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg
16384/15477 [===============================] - 0s 0us/step
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))

jpeg

display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))

jpeg

TFRecord ファイルの書き出し

上記で行ったように、この特徴量を tf.Example と互換のデータ型にエンコードできます。この場合には、生の画像文字列を特徴として保存するだけではなく、縦、横のサイズにチャネル数、更に画像を保存する際に猫の画像と橋の画像を区別するための label 特徴量を付け加えます。猫の画像には 0 を、橋の画像には 1 を使うことにしましょう。

image_labels = {
    cat_in_snow : 0,
    williamsburg_bridge : 1,
}
# 猫の画像を使った例
image_string = open(cat_in_snow, 'rb').read()

label = image_labels[cat_in_snow]

# 関連する特徴量のディクショナリを作成
def image_example(image_string, label):
  image_shape = tf.image.decode_jpeg(image_string).shape

  feature = {
      'height': _int64_feature(image_shape[0]),
      'width': _int64_feature(image_shape[1]),
      'depth': _int64_feature(image_shape[2]),
      'label': _int64_feature(label),
      'image_raw': _bytes_feature(image_string),
  }

  return tf.train.Example(features=tf.train.Features(feature=feature))

for line in str(image_example(image_string, label)).split('\n')[:15]:
  print(line)
print('...')
features {
  feature {
    key: "depth"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "height"
    value {
      int64_list {
        value: 213
      }
...

ご覧のように、すべての特徴量が tf.Example メッセージに保存されました。上記のコードを関数化し、このサンプルメッセージを images.tfrecords ファイルに書き込みます。

# 生の画像を images.tfrecords ファイルに書き出す
# まず、2つの画像を tf.Example メッセージに変換し、
# 次に .tfrecords ファイルに書き出す
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for filename, label in image_labels.items():
    image_string = open(filename, 'rb').read()
    tf_example = image_example(image_string, label)
    writer.write(tf_example.SerializeToString())
!du -sh {record_file}
36K images.tfrecords

TFRecord ファイルの読み込み

これで、images.tfrecords ファイルができました。このファイルの中のレコードをイテレートし、書き込んだものを読み出します。このユースケースでは、画像を復元するだけなので、必要なのは生画像の文字列だけです。上記のゲッター、すなわち、example.features.feature['image_raw'].bytes_list.value[0] を使って抽出することができます。猫と橋のどちらであるかを決めるため、ラベルも使用します。

raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

# 特徴量を記述するディクショナリを作成
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  # 入力の tf.Example のプロトコルバッファを上記のディクショナリを使って解釈
  return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
<MapDataset shapes: {depth: (), height: (), image_raw: (), label: (), width: ()}, types: {depth: tf.int64, height: tf.int64, image_raw: tf.string, label: tf.int64, width: tf.int64}>

TFRecord ファイルから画像を復元しましょう。

for image_features in parsed_image_dataset:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))

jpeg

jpeg