![]() |
![]() |
![]() |
![]() |
データの読み込みを効率的にするには、データをシリアライズし、連続的に読み込めるファイルのセット(各ファイルは 100-200MB)に保存することが有効です。データをネットワーク経由で流そうとする場合には、特にそうです。また、データの前処理をキャッシングする際にも役立ちます。
TFRecord 形式は、バイナリレコードのシリーズを保存するための単純な形式です。
プロトコルバッファ は、構造化データを効率的にシリアライズする、プラットフォームや言語に依存しないライブラリです。
プロトコルメッセージは .proto
という拡張子のファイルで表されます。メッセージの型を識別するもっとも簡単な方法です。
tf.Example
メッセージ(あるいはプロトコルバッファ)は、{"string": value}
形式のマッピングを表現する柔軟なメッセージ型です。これは、TensorFlow 用に設計され、TFX のような上位レベルの API で共通に使用されています。
このノートブックでは、tf.Example
メッセージの作成、解析と使用法をデモし、その後、tf.Example
メッセージをシリアライズして .tfrecord
に書き出し、その後読み取る方法を示します。
注:こうした構造は有用ですが必ずそうしなければならなというものではありません。tf.data
を使っていて、それでもデータの読み込みが訓練のボトルネックであるという場合でなければ、既存のコードを TFRecords を使用するために変更する必要はありません。データセットの性能改善のヒントは、 Data Input Pipeline Performance を参照ください。
設定
import tensorflow as tf
import numpy as np
import IPython.display as display
tf.Example
tf.Example
用のデータ型
基本的には tf.Example
は {"string": tf.train.Feature}
というマッピングです。
tf.train.Feature
メッセージ型は次の3つの型のうち1つをとることができます。(.proto file を参照)このほかの一般的なデータ型のほとんどは、強制的にこれらのうちの1つにすること可能です。
tf.train.BytesList
(次の型のデータを扱うことが可能)string
byte
tf.train.FloatList
(次の型のデータを扱うことが可能)float
(float32
)double
(float64
)
tf.train.Int64List
(次の型のデータを扱うことが可能)bool
enum
int32
uint32
int64
uint64
通常の TensorFlow の型を tf.Example
互換の tf.train.Feature
に変換するには、次のショートカット関数を使うことができます。
どの関数も、1個のスカラー値を入力とし、上記の3つの list
型のうちの一つを含む tf.train.Feature
を返します。
# 下記の関数を使うと値を tf.Example と互換性の有る型に変換できる
def _bytes_feature(value):
"""string / byte 型から byte_list を返す"""
if isinstance(value, type(tf.constant(0))):
value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
"""float / double 型から float_list を返す"""
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
"""bool / enum / int / uint 型から Int64_list を返す"""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
注:単純化のため、このサンプルではスカラー値の入力のみを扱っています。スカラー値ではない特徴を扱うもっとも簡単な方法は、tf.serialize_tensor
を使ってテンソルをバイナリ文字列に変換する方法です。TensorFlow では文字列はスカラー値として扱います。バイナリ文字列をテンソルに戻すには、tf.parse_tensor
を使用します。
上記の関数の使用例を下記に示します。入力がさまざまな型であるのに対して、出力が標準化されていることに注目してください。入力が、強制変換できない型であった場合、例外が発生します。(例: _int64_feature(1.0)
はエラーとなります。1.0
が浮動小数点数であるためで、代わりに _float_feature
関数を使用すべきです)
print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))
print(_float_feature(np.exp(1)))
print(_int64_feature(True))
print(_int64_feature(1))
bytes_list { value: "test_string" } bytes_list { value: "test_bytes" } float_list { value: 2.7182817459106445 } int64_list { value: 1 } int64_list { value: 1 }
主要なメッセージはすべて .SerializeToString
を使ってバイナリ文字列にシリアライズすることができます。
feature = _float_feature(np.exp(1))
feature.SerializeToString()
b'\x12\x06\n\x04T\xf8-@'
tf.Example
メッセージの作成
既存のデータから tf.Example
を作成したいとします。実際には、データセットの出処はどこでもよいのですが、1件の観測記録からtf.Example
メッセージを作る手順はおなじです。
観測記録それぞれにおいて、各値は上記の関数を使って3種類の互換性のある型からなる
tf.train.Feature
に変換する必要があります。次に、特徴の名前を表す文字列と、#1 で作ったエンコード済みの特徴量を対応させたマップ(ディクショナリ)を作成します。
ステップ2 で作成したマップを特徴量メッセージに変換します。
このノートブックでは、NumPy を使ってデータセットを作成します。
このデータセットには4つの特徴量があります。
False
またはTrue
を表す論理値。出現確率は等しいものとします。- ランダムなバイト値。全体において一様であるとします。
[-10000, 10000)
の範囲から一様にサンプリングした整数値。- 標準正規分布からサンプリングした浮動小数点数。
サンプルは上記の分布から独立しておなじ様に分布した10,000件の観測記録からなるものとします。
# データセットに含まれる観測結果の件数
n_observations = int(1e4)
# ブール特徴量 False または True としてエンコードされている
feature0 = np.random.choice([False, True], n_observations)
# 整数特徴量 -10000 から 10000 の間の乱数
feature1 = np.random.randint(0, 5, n_observations)
# バイト特徴量
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]
# 浮動小数点数特徴量 標準正規分布から発生
feature3 = np.random.randn(n_observations)
これらの特徴量は、_bytes_feature
, _float_feature
, _int64_feature
のいずれかを使って tf.Example
互換の型に強制変換されます。その後、エンコード済みの特徴量から tf.Example
メッセージを作成できます。
def serialize_example(feature0, feature1, feature2, feature3):
"""
ファイル出力可能な tf.Example メッセージを作成する
"""
# 特徴量名と tf.Example 互換データ型との対応ディクショナリを作成
feature = {
'feature0': _int64_feature(feature0),
'feature1': _int64_feature(feature1),
'feature2': _bytes_feature(feature2),
'feature3': _float_feature(feature3),
}
# tf.train.Example を用いて特徴メッセージを作成
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
たとえば、データセットに [False, 4, bytes('goat'), 0.9876]
という1つの観測記録があるとします。create_message()
を使うとこの観測記録から tf.Example
メッセージを作成し印字できます。上記のように、観測記録一つ一つが Features
メッセージとして書かれています。tf.Example
メッセージは、この Features
メッセージを包むラッパーに過ぎないことに注意してください。
# データセットからの観測記録の例
example_observation = []
serialized_example = serialize_example(False, 4, b'goat', 0.9876)
serialized_example
b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?'
メッセージをデコードするには、tf.train.Example.FromString
メソッドを使用します。
example_proto = tf.train.Example.FromString(serialized_example)
example_proto
features { feature { key: "feature0" value { int64_list { value: 0 } } } feature { key: "feature1" value { int64_list { value: 4 } } } feature { key: "feature2" value { bytes_list { value: "goat" } } } feature { key: "feature3" value { float_list { value: 0.9876000285148621 } } } }
TFRecord フォーマットの詳細
TFRecord ファイルにはレコードのシーケンスが含まれます。このファイルはシーケンシャル読み取りのみが可能です。
それぞれのレコードには、データを格納するためのバイト文字列とデータ長、そして整合性チェックのための CRC32C(Castagnoli 多項式を使った 32 ビットの CRC )ハッシュ値が含まれます。
各レコードのフォーマットは
uint64 長さ
uint32 長さのマスク済み crc32 ハッシュ値
byte data[長さ]
uint32 データのマスク済み crc32 ハッシュ値
複数のレコードが結合されてファイルを構成します。CRC についてはここに說明があります。CRC のマスクは下記のとおりです。
masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul
注:TFRecord ファイルを作るのに、tf.Example
を使わなければならないということはありません。tf.Example は、ディクショナリをバイト文字列にシリアライズする方法の1つです。エンコードされた画像データや、(tf.io.serialize_tensor
を使ってシリアライズされ、tf.io.parse_tensor
で読み込まれる)シリアライズされたテンソルもあります。そのほかのオプションについては、tf.io
モジュールを参照してください。
tf.data
を使用した TFRecord ファイル
tf.data
モジュールには、TensorFlow でデータを読み書きするツールが含まれます。
TFRecord ファイルの書き出し
データをデータセットにするもっとも簡単な方法は from_tensor_slices
メソッドです。
配列に適用すると、このメソッドはスカラー値のデータセットを返します。
tf.data.Dataset.from_tensor_slices(feature1)
<TensorSliceDataset shapes: (), types: tf.int64>
配列のタプルに適用すると、タプルのデータセットが返されます。
features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))
features_dataset
<TensorSliceDataset shapes: ((), (), (), ()), types: (tf.bool, tf.int64, tf.string, tf.float64)>
# データセットから1つのサンプルだけを取り出すには `take(1)` を使います。
for f0,f1,f2,f3 in features_dataset.take(1):
print(f0)
print(f1)
print(f2)
print(f3)
tf.Tensor(False, shape=(), dtype=bool) tf.Tensor(1, shape=(), dtype=int64) tf.Tensor(b'dog', shape=(), dtype=string) tf.Tensor(-1.8443852481470147, shape=(), dtype=float64)
Dataset
のそれぞれの要素に関数を適用するには、tf.data.Dataset.map
メソッドを使用します。
マップされる関数は TensorFlow のグラフモードで動作する必要があります。関数は tf.Tensors
を処理し、返す必要があります。create_example
のような非テンソル関数は、互換性のため tf.py_func
でラップすることができます。
tf.py_func
を使用する際には、シェイプと型は取得できないため、指定する必要があります。
def tf_serialize_example(f0,f1,f2,f3):
tf_string = tf.py_function(
serialize_example,
(f0,f1,f2,f3), # 上記の関数にこれらの引数を渡す
tf.string) # 戻り値の型は tf.string
return tf.reshape(tf_string, ()) # 結果はスカラー
tf_serialize_example(f0,f1,f2,f3)
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xd1\x14\xec\xbf\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01'>
この関数をデータセットのそれぞれの要素に適用します。
serialized_features_dataset = features_dataset.map(tf_serialize_example)
serialized_features_dataset
<MapDataset shapes: (), types: tf.string>
def generator():
for features in features_dataset:
yield serialize_example(*features)
serialized_features_dataset = tf.data.Dataset.from_generator(
generator, output_types=tf.string, output_shapes=())
serialized_features_dataset
<FlatMapDataset shapes: (), types: tf.string>
TFRecord ファイルに書き出します。
filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)
TFRecord ファイルの読み込み
tf.data.TFRecordDataset
クラスを使って TFRecord ファイルを読み込むこともできます。
tf.data
を使って TFRecord ファイルを取り扱う際の詳細については、こちらを参照ください。
TFRecordDataset
を使うことは、入力データを標準化し、パフォーマンスを最適化するのに有用です。
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>
この時点で、データセットにはシリアライズされた tf.train.Example
メッセージが含まれています。データセットをイテレートすると、スカラーの文字列テンソルが返ってきます。
.take
メソッドを使って最初の 10 レコードだけを表示します。
注:tf.data.Dataset
をイテレートできるのは、Eager Execution が有効になっている場合のみです。
for raw_record in raw_dataset.take(10):
print(repr(raw_record))
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xd1\x14\xec\xbf'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xd6\x1bb?'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xf4\x95\x8e>'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x1f\x96\xc9\xbd'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04c\x89\xf3?'> <tf.Tensor: shape=(), dtype=string, numpy=b"\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04o'\x9e?"> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04$&\x1a?'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xeb\xed\xc6\xbe'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x00\x82\xd0\xbf\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\rN\xd2\xbe'>
これらのテンソルは下記の関数でパースできます。
注:ここでは、feature_description
が必要です。データセットはグラフ実行を使用するため、この記述を使ってシェイプと型を構築するのです。
# 特徴の記述
feature_description = {
'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}
def _parse_function(example_proto):
# 上記の記述を使って入力の tf.Example を処理
return tf.io.parse_single_example(example_proto, feature_description)
あるいは、tf.parse example
を使ってバッチ全体を一度にパースします。
tf.data.Dataset.map
メソッドを使って、データセットの各アイテムにこの関数を適用します。
parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset
<MapDataset shapes: {feature0: (), feature1: (), feature2: (), feature3: ()}, types: {feature0: tf.int64, feature1: tf.int64, feature2: tf.string, feature3: tf.float32}>
Eager Execution を使ってデータセット中の観測記録を表示します。このデータセットには 10,000 件の観測記録がありますが、最初の 10 個だけ表示します。
データは特徴量のディクショナリの形で表示されます。それぞれの項目は tf.Tensor
であり、このテンソルの numpy
要素は特徴量を表します。
for parsed_record in parsed_dataset.take(10):
print(repr(raw_record))
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\rN\xd2\xbe'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\rN\xd2\xbe'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\rN\xd2\xbe'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\rN\xd2\xbe'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\rN\xd2\xbe'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\rN\xd2\xbe'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\rN\xd2\xbe'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\rN\xd2\xbe'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\rN\xd2\xbe'> <tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\rN\xd2\xbe'>
ここでは、tf.parse_example
がtf.Example
のフィールドを通常のテンソルに展開しています。
tf.python_io を使った TFRecord ファイル
tf.python_io
モジュールには、TFRecord ファイルの読み書きのための純粋な Python 関数も含まれています。
TFRecord ファイルの書き出し
次にこの 10,000 件の観測記録を test.tfrecords
ファイルに出力します。観測記録はそれぞれ tf.Example
メッセージに変換され、ファイルに出力されます。その後、test.tfrecords
ファイルが作成されたことを確認することができます。
# `tf.Example` 観測記録をファイルに出力
with tf.io.TFRecordWriter(filename) as writer:
for i in range(n_observations):
example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
writer.write(example)
du -sh {filename}
984K test.tfrecord
TFRecord ファイルの読み込み
これらのシリアライズされたテンソルは、tf.train.Example.ParseFromString
を使って簡単にパースできます。
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>
for raw_record in raw_dataset.take(1):
example = tf.train.Example()
example.ParseFromString(raw_record.numpy())
print(example)
features { feature { key: "feature0" value { int64_list { value: 0 } } } feature { key: "feature1" value { int64_list { value: 1 } } } feature { key: "feature2" value { bytes_list { value: "dog" } } } feature { key: "feature3" value { float_list { value: -1.8443852663040161 } } } }
ウォークスルー: 画像データの読み書き
以下は、TFRecord を使って画像データを読み書きする方法の例です。この例の目的は、データ(この場合は画像)を入力し、そのデータを TFRecord ファイルに書き込んで、再びそのファイルを読み込み、画像を表示するという手順を最初から最後まで示すことです。
これは、たとえば、おなじ入力データセットを使って複数のモデルを構築するといった場合に役立ちます。画像データをそのまま保存する代わりに、TFRecord 形式に前処理しておき、その後の処理やモデル構築に使用することができます。
まずは、雪の中の猫の画像と、ニューヨーク市にあるウイリアムズバーグ橋の 写真をダウンロードしましょう。
画像の取得
cat_in_snow = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg 24576/17858 [=========================================] - 0s 0us/step Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg 16384/15477 [===============================] - 0s 0us/step
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))
display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))
TFRecord ファイルの書き出し
上記で行ったように、この特徴量を tf.Example
と互換のデータ型にエンコードできます。この場合には、生の画像文字列を特徴として保存するだけではなく、縦、横のサイズにチャネル数、更に画像を保存する際に猫の画像と橋の画像を区別するための label
特徴量を付け加えます。猫の画像には 0
を、橋の画像には 1
を使うことにしましょう。
image_labels = {
cat_in_snow : 0,
williamsburg_bridge : 1,
}
# 猫の画像を使った例
image_string = open(cat_in_snow, 'rb').read()
label = image_labels[cat_in_snow]
# 関連する特徴量のディクショナリを作成
def image_example(image_string, label):
image_shape = tf.image.decode_jpeg(image_string).shape
feature = {
'height': _int64_feature(image_shape[0]),
'width': _int64_feature(image_shape[1]),
'depth': _int64_feature(image_shape[2]),
'label': _int64_feature(label),
'image_raw': _bytes_feature(image_string),
}
return tf.train.Example(features=tf.train.Features(feature=feature))
for line in str(image_example(image_string, label)).split('\n')[:15]:
print(line)
print('...')
features { feature { key: "depth" value { int64_list { value: 3 } } } feature { key: "height" value { int64_list { value: 213 } ...
ご覧のように、すべての特徴量が tf.Example
メッセージに保存されました。上記のコードを関数化し、このサンプルメッセージを images.tfrecords
ファイルに書き込みます。
# 生の画像を images.tfrecords ファイルに書き出す
# まず、2つの画像を tf.Example メッセージに変換し、
# 次に .tfrecords ファイルに書き出す
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
for filename, label in image_labels.items():
image_string = open(filename, 'rb').read()
tf_example = image_example(image_string, label)
writer.write(tf_example.SerializeToString())
du -sh {record_file}
36K images.tfrecords
TFRecord ファイルの読み込み
これで、images.tfrecords
ファイルができました。このファイルの中のレコードをイテレートし、書き込んだものを読み出します。このユースケースでは、画像を復元するだけなので、必要なのは生画像の文字列だけです。上記のゲッター、すなわち、example.features.feature['image_raw'].bytes_list.value[0]
を使って抽出することができます。猫と橋のどちらであるかを決めるため、ラベルも使用します。
raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')
# 特徴量を記述するディクショナリを作成
image_feature_description = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'label': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
def _parse_image_function(example_proto):
# 入力の tf.Example のプロトコルバッファを上記のディクショナリを使って解釈
return tf.io.parse_single_example(example_proto, image_feature_description)
parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
<MapDataset shapes: {depth: (), height: (), image_raw: (), label: (), width: ()}, types: {depth: tf.int64, height: tf.int64, image_raw: tf.string, label: tf.int64, width: tf.int64}>
TFRecord ファイルから画像を復元しましょう。
for image_features in parsed_image_dataset:
image_raw = image_features['image_raw'].numpy()
display.display(display.Image(data=image_raw))