Ce document fournit des conseils de performances spécifiques aux ensembles de données TensorFlow (TFDS). Notez que TFDS fournit des ensembles de données en tant tf.data.Dataset
, de sorte que les conseils du guide tf.data
s'appliquent toujours.
Ensembles de données de référence
Utilisez tfds.benchmark(ds)
pour comparer n'importe quel objet tf.data.Dataset
.
Assurez-vous d'indiquer le batch_size=
pour normaliser les résultats (par exemple 100 iter/sec -> 3200 ex/sec). Cela fonctionne avec n'importe quel itérable (par exemple tfds.benchmark(tfds.as_numpy(ds))
).
ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)
Petits ensembles de données (moins de 1 Go)
Tous les jeux de données TFDS stockent les données sur disque au format TFRecord
. Pour les petits ensembles de données (par exemple, MNIST, CIFAR-10/-100), la lecture à partir de .tfrecord
peut ajouter une surcharge importante.
Comme ces ensembles de données tiennent en mémoire, il est possible d'améliorer considérablement les performances en mettant en cache ou en préchargeant l'ensemble de données. Notez que TFDS met automatiquement en cache les petits ensembles de données (la section suivante contient les détails).
Mise en cache du jeu de données
Voici un exemple de pipeline de données qui met explicitement en cache l'ensemble de données après la normalisation des images.
def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255., label
ds, ds_info = tfds.load(
'mnist',
split='train',
as_supervised=True, # returns `(img, label)` instead of dict(image=, ...)
with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
Lors de l'itération sur ce jeu de données, la deuxième itération sera beaucoup plus rapide que la première grâce à la mise en cache.
Mise en cache automatique
Par défaut, TFDS met automatiquement en cache (avec ds.cache()
) les ensembles de données qui satisfont aux contraintes suivantes :
- La taille totale de l'ensemble de données (toutes les divisions) est définie et < 250 Mio
-
shuffle_files
est désactivé ou un seul fragment est lu
Il est possible de désactiver la mise en cache automatique en transmettant try_autocaching=False
à tfds.ReadConfig
dans tfds.load
. Consultez la documentation du catalogue de jeux de données pour voir si un jeu de données spécifique utilisera la mise en cache automatique.
Chargement des données complètes en tant que Tensor unique
Si votre jeu de données tient dans la mémoire, vous pouvez également charger le jeu de données complet en tant que tableau Tensor ou NumPy unique. Il est possible de le faire en définissant batch_size=-1
pour grouper tous les exemples dans un seul tf.Tensor
. Utilisez ensuite tfds.as_numpy
pour la conversion de tf.Tensor
en np.array
.
(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
'mnist',
split=['train', 'test'],
batch_size=-1,
as_supervised=True,
))
Grands ensembles de données
Les ensembles de données volumineux sont fragmentés (divisés en plusieurs fichiers) et ne tiennent généralement pas dans la mémoire, ils ne doivent donc pas être mis en cache.
Shuffle et entraînement
Pendant l'entraînement, il est important de bien mélanger les données - des données mal mélangées peuvent entraîner une précision d'entraînement moindre.
En plus d'utiliser ds.shuffle
pour mélanger les enregistrements, vous devez également définir shuffle_files=True
pour obtenir un bon comportement de mélange pour les ensembles de données plus volumineux qui sont divisés en plusieurs fichiers. Sinon, les époques liront les fragments dans le même ordre, et les données ne seront donc pas vraiment randomisées.
ds = tfds.load('imagenet2012', split='train', shuffle_files=True)
De plus, lorsque shuffle_files=True
, TFDS désactive options.deterministic
, ce qui peut améliorer légèrement les performances. Pour obtenir un brassage déterministe, il est possible de désactiver cette fonctionnalité avec tfds.ReadConfig
: soit en définissant read_config.shuffle_seed
, soit en écrasant read_config.options.deterministic
.
Partage automatique de vos données entre les travailleurs (TF)
Lors de la formation sur plusieurs nœuds de calcul, vous pouvez utiliser l'argument input_context
de tfds.ReadConfig
, afin que chaque nœud de calcul lise un sous-ensemble des données.
input_context = tf.distribute.InputContext(
input_pipeline_id=1, # Worker id
num_input_pipelines=4, # Total number of workers
)
read_config = tfds.ReadConfig(
input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)
Ceci est complémentaire à l'API subsplit. Tout d'abord, l'API subplit est appliquée : train[:50%]
est converti en une liste de fichiers à lire. Ensuite, une ds.shard()
est appliquée sur ces fichiers. Par exemple, lors de l'utilisation train[:50%]
avec num_input_pipelines=2
, chacun des 2 travailleurs lira 1/4 des données.
Lorsque shuffle_files=True
, les fichiers sont mélangés au sein d'un travailleur, mais pas entre les travailleurs. Chaque travailleur lira le même sous-ensemble de fichiers entre les époques.
Partage automatique de vos données entre les travailleurs (Jax)
Avec Jax, vous pouvez utiliser l'API tfds.split_for_jax_process
ou tfds.even_splits
pour distribuer vos données entre les nœuds de calcul. Voir le guide de l'API fractionnée .
split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)
tfds.split_for_jax_process
est un simple alias pour :
# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]
Décodage d'image plus rapide
Par défaut, TFDS décode automatiquement les images. Cependant, il existe des cas où il peut être plus performant d'ignorer le décodage d'image avec tfds.decode.SkipDecoding
et d'appliquer manuellement l'op tf.io.decode_image
:
- Lors du filtrage d'exemples (avec
tf.data.Dataset.filter
), pour décoder les images après le filtrage des exemples. - Lors du recadrage d'images, utilisez l'opération fusionnée
tf.image.decode_and_crop_jpeg
.
Le code des deux exemples est disponible dans le guide de décodage .
Ignorer les fonctionnalités inutilisées
Si vous n'utilisez qu'un sous-ensemble de fonctionnalités, il est possible d'ignorer complètement certaines fonctionnalités. Si votre jeu de données contient de nombreuses fonctionnalités inutilisées, ne pas les décoder peut améliorer considérablement les performances. Voir https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features