성능 팁

이 문서는 TensorFlow Datasets(TFDS) 관련 성능 팁을 제공합니다. TFDS는 데이터 세트를 tf.data.Dataset 객체로 제공하므로 tf.data 가이드 의 조언이 계속 적용됩니다.

벤치마크 데이터세트

tfds.benchmark(ds) 를 사용하여 tf.data.Dataset 객체를 벤치마킹합니다.

결과를 정규화하려면 batch_size= 를 지정해야 합니다(예: 100 iter/sec -> 3200 ex/sec). 이것은 모든 iterable에서 작동합니다(예: tfds.benchmark(tfds.as_numpy(ds)) ).

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

작은 데이터 세트(1GB 미만)

모든 TFDS 데이터 세트는 TFRecord 형식으로 디스크에 데이터를 저장합니다. 작은 데이터 세트(예: MNIST, CIFAR-10/-100)의 경우 .tfrecord 에서 읽으면 상당한 오버헤드가 추가될 수 있습니다.

이러한 데이터 세트가 메모리에 적합하므로 데이터 세트를 캐싱하거나 사전 로드하여 성능을 크게 향상시킬 수 있습니다. TFDS는 작은 데이터 세트를 자동으로 캐시합니다(자세한 내용은 다음 섹션 참조).

데이터세트 캐싱

다음은 이미지를 정규화한 후 데이터 세트를 명시적으로 캐시하는 데이터 파이프라인의 예입니다.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

이 데이터 세트를 반복할 때 캐싱 덕분에 두 번째 반복이 첫 번째 반복보다 훨씬 빠릅니다.

자동 캐싱

기본적으로 TFDS는 다음 제약 조건을 충족하는 데이터 세트 ds.cache() 사용)를 자동 캐시합니다.

  • 총 데이터 세트 크기(모든 분할)가 정의되고 < 250MiB
  • shuffle_files 가 비활성화되었거나 단일 샤드만 읽습니다.

tfds.load의 tfds.ReadConfigtfds.load try_autocaching=False 를 전달하여 자동 캐싱을 옵트아웃할 수 있습니다. 특정 데이터세트가 자동 캐시를 사용할지 알아보려면 데이터세트 카탈로그 문서를 살펴보세요.

전체 데이터를 단일 Tensor로 로드

데이터 세트가 메모리에 맞는 경우 전체 데이터 세트를 단일 Tensor 또는 NumPy 배열로 로드할 수도 있습니다. 단일 tf.Tensor 에서 모든 예제를 일괄 처리하도록 batch_size=-1 을 설정하여 그렇게 할 수 있습니다. 그런 다음 tfds.as_numpy 에서 tf.Tensor 로의 변환을 위해 np.array 를 사용하십시오.

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

대규모 데이터세트

큰 데이터 세트는 분할(여러 파일로 분할)되며 일반적으로 메모리에 맞지 않으므로 캐시하면 안 됩니다.

셔플 및 훈련

훈련 중에는 데이터를 잘 섞는 것이 중요합니다. 데이터를 잘못 섞으면 훈련 정확도가 낮아질 수 있습니다.

ds.shuffle 을 사용하여 레코드를 섞는 것 외에도 여러 파일로 분할되는 더 큰 데이터 세트에 대해 좋은 섞기 동작을 얻으려면 shuffle_files=True 를 설정해야 합니다. 그렇지 않으면 에포크가 동일한 순서로 샤드를 읽으므로 데이터가 실제로 무작위로 지정되지 않습니다.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

또한 shuffle_files=True 인 경우 TFDS는 options.deterministic 을 비활성화하여 성능을 약간 향상시킬 수 있습니다. 결정적 셔플링을 얻으려면 tfds.ReadConfig 를 설정하거나 read_config.shuffle_seed 을 덮어 read_config.options.deterministic 를 사용하여 이 기능을 옵트아웃할 수 있습니다.

작업자 간에 데이터 자동 샤딩(TF)

여러 작업자에 대해 교육할 때 tfds.ReadConfiginput_context 인수를 사용할 수 있으므로 각 작업자는 데이터의 하위 집합을 읽습니다.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

이는 subsplit API를 보완합니다. 먼저 subplit API가 적용됩니다. train[:50%] 을 읽을 파일 목록으로 변환합니다. 그런 다음 해당 파일에 ds.shard() 이 적용됩니다. 예를 들어 num_input_pipelines=2 와 함께 train[:50%] 을 사용할 때 2명의 작업자는 각각 데이터의 1/4을 읽습니다.

shuffle_files=True 인 경우 파일은 한 작업자 내에서 섞이지만 작업자 간에는 섞이지 않습니다. 각 작업자는 Epoch 간에 동일한 파일 하위 집합을 읽습니다.

작업자 간에 데이터 자동 샤딩(Jax)

Jax를 사용하면 tfds.split_for_jax_process 또는 tfds.even_splits API를 사용하여 작업자 간에 데이터를 배포할 수 있습니다. 분할 API 가이드를 참조하세요.

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process 는 다음의 간단한 별칭입니다.

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

더 빠른 이미지 디코딩

기본적으로 TFDS는 이미지를 자동으로 디코딩합니다. 그러나 tfds.decode.SkipDecoding 으로 이미지 디코딩을 건너뛰고 tf.io.decode_image 연산을 수동으로 적용하는 것이 더 성능이 좋은 경우가 있습니다.

두 예제에 대한 코드는 디코딩 가이드 에서 사용할 수 있습니다.

사용하지 않는 기능 건너뛰기

기능의 하위 집합만 사용하는 경우 일부 기능을 완전히 건너뛸 수 있습니다. 데이터세트에 사용하지 않은 기능이 많은 경우 디코딩하지 않으면 성능이 크게 향상될 수 있습니다. https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features 참조