TFDS dan determinisme

Lihat di TensorFlow.org Jalankan di Google Colab Lihat di GitHub Unduh buku catatan

Dokumen ini menjelaskan:

  • TFDS menjamin determinisme
  • Dalam urutan apa TFDS membaca contoh?
  • Berbagai peringatan dan gotcha

Mempersiapkan

Kumpulan data

Beberapa konteks diperlukan untuk memahami bagaimana TFDS membaca data.

Selama generasi, TFDS menulis data asli ke standar .tfrecord file. Untuk dataset besar, beberapa .tfrecord file dibuat, masing-masing berisi beberapa contoh. Kami menyebutnya setiap .tfrecord mengajukan beling a.

Panduan ini menggunakan imagenet yang memiliki 1024 pecahan:

import re
import tensorflow_datasets as tfds

imagenet = tfds.builder('imagenet2012')

num_shards = imagenet.info.splits['train'].num_shards
num_examples = imagenet.info.splits['train'].num_examples
print(f'imagenet has {num_shards} shards ({num_examples} examples)')
imagenet has 1024 shards (1281167 examples)

Menemukan id contoh kumpulan data

Anda dapat melompat ke bagian berikut jika Anda hanya ingin tahu tentang determinisme.

Setiap contoh dataset secara unik diidentifikasi oleh id (misalnya 'imagenet2012-train.tfrecord-01023-of-01024__32' ). Anda dapat memulihkan ini id dengan melewati read_config.add_tfds_id = True yang akan menambah 'tfds_id' kunci dalam dict dari tf.data.Dataset .

Dalam tutorial ini, kami mendefinisikan util kecil yang akan mencetak contoh id dari dataset (dikonversi dalam bilangan bulat agar lebih mudah dibaca manusia):

Determinisme saat membaca

Bagian ini menjelaskan jaminan deterministim dari tfds.load .

Dengan shuffle_files=False (default)

Dengan TFDS standar menghasilkan contoh deterministik ( shuffle_files=False )

# Same as: imagenet.as_dataset(split='train').take(20)
print_ex_ids(imagenet, split='train', take=20)
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]

Untuk kinerja, TFDS membaca beberapa pecahan pada saat yang sama menggunakan tf.data.Dataset.interleave . Kita melihat dalam contoh ini yang TFDS beralih ke pecahan 2 setelah membaca 16 contoh ( ..., 14, 15, 1251, 1252, ... ). Lebih lanjut tentang interleave di bawah.

Demikian pula, API subsplit juga deterministik:

print_ex_ids(imagenet, split='train[67%:84%]', take=20)
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]

Jika Anda pelatihan sedang selama lebih dari satu zaman, setup di atas tidak dianjurkan karena semua zaman akan membaca pecahan dalam urutan yang sama (sehingga keacakan terbatas pada ds = ds.shuffle(buffer) ukuran buffer).

Dengan shuffle_files=True

Dengan shuffle_files=True , pecahan dikocok untuk setiap zaman, sehingga membaca tidak deterministik lagi.

print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
[568017, 329050, 329051, 329052, 329053, 329054, 329056, 329055, 568019, 568020, 568021, 568022, 568023, 568018, 568025, 568024, 568026, 568028, 568030, 568031]
[43790, 43791, 43792, 43793, 43796, 43794, 43797, 43798, 43795, 43799, 43800, 43801, 43802, 43803, 43804, 43805, 43806, 43807, 43809, 43810]

Lihat resep di bawah ini untuk mendapatkan pengocokan file deterministik.

Peringatan determinisme: argumen interleave

Mengubah read_config.interleave_cycle_length , read_config.interleave_block_length akan mengubah urutan contoh.

TFDS bergantung pada tf.data.Dataset.interleave hanya memuat beberapa pecahan sekaligus, meningkatkan kinerja dan mengurangi penggunaan memori.

Urutan contoh hanya dijamin sama untuk nilai tetap dari argumen interleave. Lihat doc interleave untuk memahami apa yang cycle_length dan block_length bersesuaian juga.

  • cycle_length=16 , block_length=16 (default, sama seperti di atas):
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
  • cycle_length=3 , block_length=2 :
read_config = tfds.ReadConfig(
    interleave_cycle_length=3,
    interleave_block_length=2,
)
print_ex_ids(imagenet, split='train', read_config=read_config, take=20)
[0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]

Dalam contoh kedua, kita melihat bahwa dataset membaca 2 ( block_length=2 ) contoh dalam beling, kemudian beralih ke pecahan berikutnya. Setiap 2 * 3 ( cycle_length=3 ) contoh, ia pergi kembali ke pecahan pertama ( shard0-ex0, shard0-ex1, shard1-ex0, shard1-ex1, shard2-ex0, shard2-ex1, shard0-ex2, shard0-ex3, shard1-ex2, shard1-ex3, shard2-ex2,... ).

Subsplit dan contoh pesanan

Setiap contoh memiliki id 0, 1, ..., num_examples-1 . The subsplit API pilih sepotong contoh (misalnya train[:x] pilih 0, 1, ..., x-1 ).

Namun, dalam subsplit, contoh tidak dibaca dalam urutan id yang meningkat (karena pecahan dan interleave).

Lebih khusus, ds.take(x) dan split='train[:x]' tidak setara!

Ini dapat dilihat dengan mudah pada contoh interleave di atas di mana contoh berasal dari pecahan yang berbeda.

print_ex_ids(imagenet, split='train', take=25)  # tfds.load(..., split='train').take(25)
print_ex_ids(imagenet, split='train[:25]', take=-1)  # tfds.load(..., split='train[:25]')
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]

Setelah 16 (block_length) contoh, .take(25) beralih ke beling berikutnya sementara train[:25] melanjutkan membaca contoh dalam dari beling pertama.

resep

Dapatkan pengacakan file deterministik

Ada 2 cara untuk melakukan pengocokan deterministik:

  1. Mengatur shuffle_seed . Catatan: Ini membutuhkan perubahan benih di setiap epoch, jika tidak shard akan dibaca dalam urutan yang sama di antara epoch.
read_config = tfds.ReadConfig(
    shuffle_seed=32,
)

# Deterministic order, different from the default shuffle_files=False above
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
  1. Menggunakan experimental_interleave_sort_fn : Ini memberikan kontrol penuh atas yang pecahan dibaca dan urutannya, daripada mengandalkan ds.shuffle pesanan.
def _reverse_order(file_instructions):
  return list(reversed(file_instructions))

read_config = tfds.ReadConfig(
    experimental_interleave_sort_fn=_reverse_order,
)

# Last shard (01023-of-01024) is read first
print_ex_ids(imagenet, split='train', read_config=read_config, take=5)
[1279916, 1279917, 1279918, 1279919, 1279920]

Dapatkan pipa preemptable deterministik

Yang ini lebih rumit. Tidak ada solusi yang mudah dan memuaskan.

  1. Tanpa ds.shuffle dan dengan menyeret deterministik, dalam teori itu harus mungkin untuk menghitung contoh yang telah membaca dan menyimpulkan yang contoh telah membaca dalam di setiap beling (sebagai fungsi cycle_length , block_length dan ketertiban beling). Kemudian skip , take untuk setiap pecahan dapat disuntikkan melalui experimental_interleave_sort_fn .

  2. Dengan ds.shuffle kemungkinan mustahil tanpa mengulang pipa pelatihan penuh. Ini akan membutuhkan menyimpan ds.shuffle negara penyangga untuk menyimpulkan yang contoh telah membaca. Contoh bisa menjadi non-kontinyu (misalnya shard5_ex2 , shard5_ex4 membaca tapi tidak shard5_ex3 ).

  3. Dengan ds.shuffle , salah satu cara adalah dengan menyimpan semua shards_ids / example_ids baca (disimpulkan dari tfds_id ), kemudian menyimpulkan petunjuk file dari itu.

Kasus yang paling sederhana untuk 1. adalah memiliki .skip(x).take(y) pertandingan train[x:x+y] pertandingan. Ini membutuhkan:

  • Set cycle_length=1 (sehingga pecahan dibaca berurutan)
  • Set shuffle_files=False
  • Jangan gunakan ds.shuffle

Seharusnya hanya digunakan pada dataset besar di mana pelatihan hanya 1 epoch. Contoh akan dibaca dalam urutan acak default.

read_config = tfds.ReadConfig(
    interleave_cycle_length=1,  # Read shards sequentially
)

print_ex_ids(imagenet, split='train', read_config=read_config, skip=40, take=22)
# If the job get pre-empted, using the subsplit API will skip at most `len(shard0)`
print_ex_ids(imagenet, split='train[40:]', read_config=read_config, take=22)
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]

Temukan pecahan/contoh mana yang dibaca untuk subsplit yang diberikan

Dengan tfds.core.DatasetInfo , Anda memiliki akses langsung ke instruksi dibaca.

imagenet.info.splits['train[44%:45%]'].file_instructions
[FileInstruction(filename='imagenet2012-train.tfrecord-00450-of-01024', skip=700, take=-1, num_examples=551),
 FileInstruction(filename='imagenet2012-train.tfrecord-00451-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00452-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00453-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00454-of-01024', skip=0, take=-1, num_examples=1252),
 FileInstruction(filename='imagenet2012-train.tfrecord-00455-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00456-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00457-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00458-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00459-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00460-of-01024', skip=0, take=1001, num_examples=1001)]