Splits et tranchage

Tous les ensembles de données TFDS exposent diverses divisions de données (par exemple 'train' , 'test' ) qui peuvent être explorées dans le catalogue .

En plus des divisions "officielles" des ensembles de données, TFDS permet de sélectionner des tranches de division(s) et diverses combinaisons.

API de découpage

Les instructions de découpage sont spécifiées dans tfds.load ou tfds.DatasetBuilder.as_dataset via split= kwarg.

ds = tfds.load('my_dataset', split='train[:75%]')
builder = tfds.builder('my_dataset')
ds = builder.as_dataset(split='test+train[:75%]')

Le fractionnement peut être :

  • Division simple ( 'train' , 'test' ) : tous les exemples dans la division sélectionnée.
  • Tranches : Les tranches ont la même sémantique que la notation de tranche python . Les tranches peuvent être :
    • Absolute ( 'train[123:450]' , train[:4000] ): (voir la note ci-dessous pour la mise en garde concernant l'ordre de lecture)
    • Pourcentage ( 'train[:75%]' , 'train[25%:75%]' ) : Divisez les données complètes en 100 tranches paires. Si les données ne sont pas divisibles par 100, certains pourcentages peuvent contenir des exemples supplémentaires.
    • Shard ( train[:4shard] , train[4shard] ) : sélectionnez tous les exemples dans le fragment demandé. (voir info.splits['train'].num_shards pour obtenir le nombre de fragments du split)
  • Union de divisions ( 'train+test' , 'train[:25%]+test' ) : les divisions seront entrelacées.
  • Ensemble de données complet ( 'all' ) : 'all' est un nom de division spécial correspondant à l'union de toutes les divisions (équivalent à 'train+test+...' ).
  • Liste des fractionnements ( ['train', 'test'] ) : plusieurs tf.data.Dataset sont renvoyés séparément :
# Returns both train and test split separately
train_ds, test_ds = tfds.load('mnist', split=['train', 'test[:50%]'])

tfds.even_splits & formation multi-hôtes

tfds.even_splits génère une liste de sous-splits non superposés de même taille.

# Divide the dataset into 3 even parts, each containing 1/3 of the data
split0, split1, split2 = tfds.even_splits('train', n=3)

ds = tfds.load('my_dataset', split=split2)

Cela peut être particulièrement utile lors de la formation dans un environnement distribué, où chaque hôte doit recevoir une tranche des données d'origine.

Avec Jax , cela peut être simplifié encore plus en utilisant tfds.split_for_jax_process :

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process est un simple alias pour :

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

tfds.even_splits , tfds.split_for_jax_process accepte n'importe quelle valeur fractionnée en entrée (par exemple 'train[75%:]+test' )

Découpage et métadonnées

Il est possible d'obtenir des informations supplémentaires sur les divisions/sous-divisions ( num_examples , file_instructions ,...) en utilisant les infos du jeu de données :

builder = tfds.builder('my_dataset')
builder.info.splits['train'].num_examples  # 10_000
builder.info.splits['train[:75%]'].num_examples  # 7_500 (also works with slices)
builder.info.splits.keys()  # ['train', 'test']

Validation croisée

Exemples de validation croisée 10 fois à l'aide de l'API de chaîne :

vals_ds = tfds.load('mnist', split=[
    f'train[{k}%:{k+10}%]' for k in range(0, 100, 10)
])
trains_ds = tfds.load('mnist', split=[
    f'train[:{k}%]+train[{k+10}%:]' for k in range(0, 100, 10)
])

Les ensembles de données de validation seront chacun de 10 % : [0%:10%] , [10%:20%] , ..., [90%:100%] . Et les ensembles de données d'entraînement vont chacun être les 90% complémentaires : [10%:100%] (pour un ensemble de validation correspondant de [0%:10%] ), `[0%:10%]

  • [20%:100%] (for a validation set of [10%:20%]`),...

tfds.core.ReadInstruction et arrondi

Plutôt que str , il est possible de passer des fractionnements en tant que tfds.core.ReadInstruction :

Par exemple, split = 'train[50%:75%] + test' équivaut à :

split = (
    tfds.core.ReadInstruction(
        'train',
        from_=50,
        to=75,
        unit='%',
    )
    + tfds.core.ReadInstruction('test')
)
ds = tfds.load('my_dataset', split=split)

unit peut être :

  • abs : Tranchage absolu
  • % : Pourcentage de découpage
  • shard : découpage en fragments

tfds.ReadInstruction a également un argument d'arrondi. Si le nombre d'exemples dans le jeu de données n'est pas divisé par 100 :

  • rounding='closest' (par défaut) : les exemples restants sont répartis entre les pourcentages, de sorte que certains pourcentages peuvent contenir des exemples supplémentaires.
  • rounding='pct1_dropremainder' : les exemples restants sont supprimés, mais cela garantit que tous les pourcentages contiennent exactement le même nombre d'exemples (par exemple : len(5%) == 5 * len(1%) ).

Reproductibilité & déterminisme

Lors de la génération, pour une version de jeu de données donnée, TFDS garantit que les exemples sont mélangés de manière déterministe sur le disque. Ainsi, générer le jeu de données deux fois (sur 2 ordinateurs différents) ne changera pas l'ordre de l'exemple.

De même, l'API subsplit sélectionnera toujours le même set d'exemples, quelle que soit la plate-forme, l'architecture, etc. Cela signifie set('train[:20%]') == set('train[:10%]') + set('train[10%:20%]') .

Cependant, l'ordre dans lequel les exemples sont lus peut ne pas être déterministe. Cela dépend d'autres paramètres (par exemple si shuffle_files=True ).