MLコミュニティデーは11月9日です! TensorFlow、JAXからの更新のために私たちに参加し、より多くの詳細をご覧ください

TFDSと決定論

TensorFlow.orgで表示 GoogleColabで実行 GitHubで表示 ノートブックをダウンロードする

このドキュメントの説明:

  • TFDSは決定論を保証します
  • TFDSはどの順序で例を読み取りますか
  • さまざまな警告と落とし穴

設定

データセット

TFDSがデータを読み取る方法を理解するには、ある程度のコンテキストが必要です。

生成中、TFDSは、標準化に元のデータを書き込む.tfrecordファイル。大きなデータセットでは、複数の.tfrecordファイルが作成され、それぞれが複数の例を含みます。私たちは、それぞれの呼び出し.tfrecordシャードファイルを。

このガイドでは、1024個のシャードを持つimagenetを使用しています。

import re
import tensorflow_datasets as tfds

imagenet = tfds.builder('imagenet2012')

num_shards = imagenet.info.splits['train'].num_shards
num_examples = imagenet.info.splits['train'].num_examples
print(f'imagenet has {num_shards} shards ({num_examples} examples)')
imagenet has 1024 shards (1281167 examples)

データセットの例のIDを見つける

決定論についてのみ知りたい場合は、次のセクションにスキップできます。

各データセットの一例は、一意で識別されるid (例えば、 'imagenet2012-train.tfrecord-01023-of-01024__32'あなたは、この回復できるid渡すことでread_config.add_tfds_id = True追加されます'tfds_id'から辞書にキーをtf.data.Dataset

このチュートリアルでは、データセットのサンプルIDを出力する小さなutilを定義します(より人間が読めるように整数に変換されます)。

読むときの決定論

このセクションでは、のdeterministim保証について説明しtfds.load

shuffle_files=False (デフォルト)

デフォルトTFDSが確定例が得られることにより( shuffle_files=False

# Same as: imagenet.as_dataset(split='train').take(20)
print_ex_ids(imagenet, split='train', take=20)
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]

パフォーマンスのために、TFDSは使用して同時に複数の破片を読んtf.data.Dataset.interleaveを。我々はTFDS 16例を読んだ後シャード2に切り替えることは、この例では参照( ..., 14, 15, 1251, 1252, ... )。インターリーブベローズの詳細。

同様に、サブスプリットAPIも決定論的です。

print_ex_ids(imagenet, split='train[67%:84%]', take=20)
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]

もし、複数のエポックのためにしている研修場合は、すべてのエポックが同じ順序(ランダム性がに制限されるように破片を読みますと、上記の設定が推奨されていませんds = ds.shuffle(buffer)サイズのバッファ)。

shuffle_files=True

shuffle_files=True読書はもう確定的ではないので、破片は、各エポックのためにシャッフルされています。

print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
[568017, 329050, 329051, 329052, 329053, 329054, 329056, 329055, 568019, 568020, 568021, 568022, 568023, 568018, 568025, 568024, 568026, 568028, 568030, 568031]
[43790, 43791, 43792, 43793, 43796, 43794, 43797, 43798, 43795, 43799, 43800, 43801, 43802, 43803, 43804, 43805, 43806, 43807, 43809, 43810]

確定的なファイルシャッフルを取得するには、以下のレシピを参照してください。

決定論の警告:インターリーブ引数

変更read_config.interleave_cycle_lengthread_config.interleave_block_length例の順序を変更します。

TFDSはに依存しているtf.data.Dataset.interleaveパフォーマンスが向上し、メモリ使用量を減らすこと、一度だけで数破片をロードします。

例の順序は、インターリーブ引数の固定値に対してのみ同じであることが保証されています。参照してくださいインターリーブドキュメントを理解するためにcycle_lengthblock_length 、あまりにも対応して。

  • cycle_length=16block_length=16 (デフォルト、同上)。
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
  • cycle_length=3block_length=2
read_config = tfds.ReadConfig(
    interleave_cycle_length=3,
    interleave_block_length=2,
)
print_ex_ids(imagenet, split='train', read_config=read_config, take=20)
[0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]

第二の例では、データセットが2(読み取ることがわかりblock_length=2シャード内)の例を、次の断片に切り替えます。すべて2 * 3( cycle_length=3 )例には、それは最初のシャード(に戻るshard0-ex0, shard0-ex1, shard1-ex0, shard1-ex1, shard2-ex0, shard2-ex1, shard0-ex2, shard0-ex3, shard1-ex2, shard1-ex3, shard2-ex2,... )。

サブスプリットと注文例

各例では、ID有する0, 1, ..., num_examples-1 subsplitのAPIは、 (例えば例のスライスを選択しtrain[:x]選択0, 1, ..., x-1

ただし、サブスプリット内では、例はIDの昇順で読み取られません(シャードとインターリーブのため)。

具体的には、 ds.take(x)split='train[:x]'等価ではありません

これは、例が異なるシャードからのものである上記のインターリーブの例で簡単に確認できます。

print_ex_ids(imagenet, split='train', take=25)  # tfds.load(..., split='train').take(25)
print_ex_ids(imagenet, split='train[:25]', take=-1)  # tfds.load(..., split='train[:25]')
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]

16(block_lengthの)例た後、 .take(25)しながら、次のシャードに切り替わりtrain[:25]最初のシャードからの例を読み続けます。

レシピ

決定論的なファイルシャッフルを取得する

決定論的なシャッフルを行うには、次の2つの方法があります。

  1. 設定shuffle_seed 。注:これには、各エポックでシードを変更する必要があります。変更しない場合、シャードはエポック間で同じ順序で読み取られます。
read_config = tfds.ReadConfig(
    shuffle_seed=32,
)

# Deterministic order, different from the default shuffle_files=False above
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
  1. 使用experimental_interleave_sort_fn :これは、破片が読まれた上で完全に制御できますし、むしろに頼るよりも、どの順序でds.shuffle順。
def _reverse_order(file_instructions):
  return list(reversed(file_instructions))

read_config = tfds.ReadConfig(
    experimental_interleave_sort_fn=_reverse_order,
)

# Last shard (01023-of-01024) is read first
print_ex_ids(imagenet, split='train', read_config=read_config, take=5)
[1279916, 1279917, 1279918, 1279919, 1279920]

決定論的なプリエンプティブパイプラインを取得する

これはもっと複雑です。簡単で満足のいく解決策はありません。

  1. なしでds.shuffleと決定論シャッフルして、理論的には、(の関数としての例は、各シャードで内読み込まれた読みとを推論されている例を数えることが可能であるべきであるcycle_lengthblock_lengthとシャード順)。その後skiptake各シャードについては、を介して注入することができexperimental_interleave_sort_fn

  2. ds.shuffleそれは完全なトレーニングパイプラインを再生することなく、可能性は不可能です。これは、保存が必要となるds.shuffle例が読まれている推測するバッファ状態を。例としては、(例えば、非連続的である可能性がshard5_ex2shard5_ex4読み取りではなくshard5_ex3 )。

  3. ds.shuffle 、一つの方法は、すべてのshards_idsを保存することです/ example_idsは(から推定読みtfds_idそのファイルからの指示を推定します、)。

最も単純なケース1.有することである.skip(x).take(y)マッチtrain[x:x+y]が一致し。必要なもの:

  • セットcycle_length=1 (破片が順次読み出されるように)
  • セットshuffle_files=False
  • 使用しないでくださいds.shuffle

トレーニングが1エポックしかない巨大なデータセットでのみ使用する必要があります。例は、デフォルトのシャッフル順序で読み取られます。

read_config = tfds.ReadConfig(
    interleave_cycle_length=1,  # Read shards sequentially
)

print_ex_ids(imagenet, split='train', read_config=read_config, skip=40, take=22)
# If the job get pre-empted, using the subsplit API will skip at most `len(shard0)`
print_ex_ids(imagenet, split='train[40:]', read_config=read_config, take=22)
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]

特定のサブスプリットに対して読み取られたシャード/例を見つけます

ではtfds.core.DatasetInfo 、あなたが読んで指示に直接アクセスすることができます。

imagenet.info.splits['train[44%:45%]'].file_instructions
[FileInstruction(filename='imagenet2012-train.tfrecord-00450-of-01024', skip=700, take=-1, num_examples=551),
 FileInstruction(filename='imagenet2012-train.tfrecord-00451-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00452-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00453-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00454-of-01024', skip=0, take=-1, num_examples=1252),
 FileInstruction(filename='imagenet2012-train.tfrecord-00455-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00456-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00457-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00458-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00459-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00460-of-01024', skip=0, take=1001, num_examples=1001)]