Cette page a été traduite par l'API Cloud Translation.
Switch to English

TFRecord et tf.train.Example

Voir sur TensorFlow.org Exécuter dans Google Colab Afficher la source sur GitHub Télécharger le carnet

Le format TFRecord est un format simple pour stocker une séquence d'enregistrements binaires.

Les tampons de protocole sont une bibliothèque multiplateforme et multilingue pour une sérialisation efficace des données structurées.

Les messages de protocole sont définis par des fichiers .proto , ce sont souvent le moyen le plus simple de comprendre un type de message.

Le message tf.train.Example (ou protobuf) est un type de message flexible qui représente un mappage {"string": value} . Il est conçu pour être utilisé avec TensorFlow et est utilisé dans toutes les API de niveau supérieur telles que TFX .

Ce bloc-notes vous montrera comment créer, analyser et utiliser le message tf.train.Example , puis sérialiser, écrire et lire les messages tf.train.Example vers et à partir de fichiers .tfrecord .

Installer

import tensorflow as tf

import numpy as np
import IPython.display as display

tf.train.Example

Types de données pour tf.train.Example

Fondamentalement, un tf.train.Example est un tf.train.Example {"string": tf.train.Feature} .

Le type de message tf.train.Feature peut accepter l'un des trois types suivants (voir le fichier .proto pour référence). La plupart des autres types génériques peuvent être contraints à l'un de ces types:

  1. tf.train.BytesList (les types suivants peuvent être forcés)

    • string
    • byte
  2. tf.train.FloatList (les types suivants peuvent être forcés)

    • float ( float32 )
    • double ( float64 )
  3. tf.train.Int64List (les types suivants peuvent être forcés)

    • bool
    • enum
    • int32
    • uint32
    • int64
    • uint64

Afin de convertir un type TensorFlow standard en un tf.train.Example -compatible tf.train.Feature , vous pouvez utiliser les fonctions de raccourci ci-dessous. Notez que chaque fonction prend une valeur d'entrée scalaire et renvoie un tf.train.Feature contenant l'un des trois types de list ci-dessus:

# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

Voici quelques exemples du fonctionnement de ces fonctions. Notez les différents types d'entrée et les types de sortie normalisés. Si le type d'entrée d'une fonction ne correspond pas à l'un des types coercibles indiqués ci-dessus, la fonction _int64_feature(1.0) une exception (par exemple, _int64_feature(1.0) produira une erreur, puisque 1.0 est un flottant, elle doit donc être utilisée avec la fonction _float_feature place) :

print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))

print(_float_feature(np.exp(1)))

print(_int64_feature(True))
print(_int64_feature(1))
bytes_list {
  value: "test_string"
}

bytes_list {
  value: "test_bytes"
}

float_list {
  value: 2.7182817459106445
}

int64_list {
  value: 1
}

int64_list {
  value: 1
}


Tous les messages proto peuvent être sérialisés en une chaîne binaire à l'aide de la méthode .SerializeToString :

feature = _float_feature(np.exp(1))

feature.SerializeToString()
b'\x12\x06\n\x04T\xf8-@'

Création d'un message tf.train.Example

Supposons que vous souhaitiez créer un message tf.train.Example partir de données existantes. En pratique, l'ensemble de données peut provenir de n'importe où, mais la procédure de création du message tf.train.Example partir d'une seule observation sera la même:

  1. Dans chaque observation, chaque valeur doit être convertie en un tf.train.Feature contenant l'un des 3 types compatibles, en utilisant l'une des fonctions ci-dessus.

  2. Vous créez une carte (dictionnaire) à partir de la chaîne de nom d'entité vers la valeur d'entité codée produite en # 1.

  3. La carte produite à l'étape 2 est convertie en un message Features .

Dans ce bloc-notes, vous allez créer un ensemble de données à l'aide de NumPy.

Cet ensemble de données aura 4 fonctionnalités:

  • une fonction booléenne, False ou True avec une probabilité égale
  • une caractéristique entière uniformément choisie au hasard parmi [0, 5]
  • une fonction de chaîne générée à partir d'une table de chaînes en utilisant la fonction d'entier comme index
  • une fonction flottante d'une distribution normale standard

Considérons un échantillon composé de 10000 observations réparties indépendamment et de manière identique de chacune des distributions ci-dessus:

# The number of observations in the dataset.
n_observations = int(1e4)

# Boolean feature, encoded as False or True.
feature0 = np.random.choice([False, True], n_observations)

# Integer feature, random from 0 to 4.
feature1 = np.random.randint(0, 5, n_observations)

# String feature
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# Float feature, from a standard normal distribution
feature3 = np.random.randn(n_observations)

Chacune de ces fonctionnalités peut être forcée dans un type compatible tf.train.Example utilisant l'une des fonctions _bytes_feature , _float_feature , _int64_feature . Vous pouvez ensuite créer un message tf.train.Example partir de ces fonctionnalités encodées:

def serialize_example(feature0, feature1, feature2, feature3):
  """
  Creates a tf.train.Example message ready to be written to a file.
  """
  # Create a dictionary mapping the feature name to the tf.train.Example-compatible
  # data type.
  feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
  }

  # Create a Features message using tf.train.Example.

  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

Par exemple, supposons que vous ayez une seule observation de l'ensemble de données, [False, 4, bytes('goat'), 0.9876] . Vous pouvez créer et imprimer le message tf.train.Example pour cette observation en utilisant create_message() . Chaque observation unique sera écrite sous la forme d'un message de Features comme indiqué ci-dessus. Notez que le message tf.train.Example est juste un wrapper autour du message Features :

# This is an example observation from the dataset.

example_observation = []

serialized_example = serialize_example(False, 4, b'goat', 0.9876)
serialized_example
b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'

Pour décoder le message, utilisez la méthode tf.train.Example.FromString .

example_proto = tf.train.Example.FromString(serialized_example)
example_proto
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.9876000285148621
      }
    }
  }
}

Détails du format TFRecords

Un fichier TFRecord contient une séquence d'enregistrements. Le fichier ne peut être lu que séquentiellement.

Chaque enregistrement contient une chaîne d'octets, pour la charge de données, plus la longueur des données, et CRC32C (CRC 32 bits utilisant le polynôme de Castagnoli) pour la vérification de l'intégrité.

Chaque enregistrement est stocké dans les formats suivants:

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

Les enregistrements sont concaténés pour produire le fichier. Les CRC sont décrits ici , et le masque d'un CRC est:

masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul

Fichiers tf.data utilisant tf.data

Le module tf.data fournit également des outils pour lire et écrire des données dans TensorFlow.

Ecrire un fichier TFRecord

Le moyen le plus simple d'obtenir les données dans un ensemble de données est d'utiliser la méthode from_tensor_slices .

Appliqué à un tableau, il renvoie un ensemble de données de scalaires:

tf.data.Dataset.from_tensor_slices(feature1)
<TensorSliceDataset shapes: (), types: tf.int64>

Appliqué à un tuple de tableaux, il renvoie un ensemble de données de tuples:

features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))
features_dataset
<TensorSliceDataset shapes: ((), (), (), ()), types: (tf.bool, tf.int64, tf.string, tf.float64)>
# Use `take(1)` to only pull one example from the dataset.
for f0,f1,f2,f3 in features_dataset.take(1):
  print(f0)
  print(f1)
  print(f2)
  print(f3)
tf.Tensor(False, shape=(), dtype=bool)
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(b'cat', shape=(), dtype=string)
tf.Tensor(-0.07564599618591197, shape=(), dtype=float64)

Utilisez la méthode tf.data.Dataset.map pour appliquer une fonction à chaque élément d'un Dataset .

La fonction mappée doit fonctionner en mode graphique TensorFlow - elle doit fonctionner sur et renvoyer tf.Tensors . Une fonction non tensorielle, comme serialize_example , peut être tf.py_function avec tf.py_function pour la rendre compatible.

L'utilisation de tf.py_function nécessite de spécifier les informations de forme et de type qui ne tf.py_function pas disponibles autrement:

def tf_serialize_example(f0,f1,f2,f3):
  tf_string = tf.py_function(
    serialize_example,
    (f0,f1,f2,f3),  # pass these args to the above function.
    tf.string)      # the return type is `tf.string`.
  return tf.reshape(tf_string, ()) # The result is a scalar
tf_serialize_example(f0,f1,f2,f3)
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04J\xec\x9a\xbd\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00'>

Appliquez cette fonction à chaque élément de l'ensemble de données:

serialized_features_dataset = features_dataset.map(tf_serialize_example)
serialized_features_dataset
<MapDataset shapes: (), types: tf.string>
def generator():
  for features in features_dataset:
    yield serialize_example(*features)
serialized_features_dataset = tf.data.Dataset.from_generator(
    generator, output_types=tf.string, output_shapes=())
serialized_features_dataset
<FlatMapDataset shapes: (), types: tf.string>

Et écrivez-les dans un fichier TFRecord:

filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)

Lire un fichier TFRecord

Vous pouvez également lire le fichier TFRecord à l'aide de la classe tf.data.TFRecordDataset .

Plus d'informations sur la consommation de fichiers TFRecord à l'aide de tf.data peuvent être trouvées ici .

L'utilisation de TFRecordDataset s peut être utile pour normaliser les données d'entrée et optimiser les performances.

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>

À ce stade, l'ensemble de données contient des messages tf.train.Example sérialisés. Lorsqu'il est itéré, il les renvoie sous forme de tenseurs de chaîne scalaires.

Utilisez la méthode .take pour n'afficher que les 10 premiers enregistrements.

for raw_record in raw_dataset.take(10):
  print(repr(raw_record))
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04J\xec\x9a\xbd\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x1f\xe0\xcb?\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04q\xa9\xb8>\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x93|+?\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xa0X}?\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\x19\x11\xc0\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x14\n\x08feature3\x12\x08\x12\x06\n\x0473\x12>\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xa2\xf7\xf9>'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xba\xf8\xb1?'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xb71\xe5>\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat'>

Ces tenseurs peuvent être analysés à l'aide de la fonction ci-dessous. Notez que feature_description est nécessaire ici car les ensembles de données utilisent l'exécution de graphes et ont besoin de cette description pour construire leur signature de forme et de type:

# Create a description of the features.
feature_description = {
    'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}

def _parse_function(example_proto):
  # Parse the input `tf.train.Example` proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, feature_description)

Vous pouvez également utiliser l' tf.parse example pour analyser l'ensemble du lot en une seule fois. Appliquez cette fonction à chaque élément de l'ensemble de données à l'aide de la méthode tf.data.Dataset.map :

parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset
<MapDataset shapes: {feature0: (), feature1: (), feature2: (), feature3: ()}, types: {feature0: tf.int64, feature1: tf.int64, feature2: tf.string, feature3: tf.float32}>

Utilisez une exécution impatiente pour afficher les observations dans l'ensemble de données. Il y a 10 000 observations dans ce jeu de données, mais vous n'afficherez que les 10 premières. Les données sont affichées sous forme de dictionnaire d'entités. Chaque élément est un tf.Tensor , et l'élément numpy de ce tenseur affiche la valeur de la fonction:

for parsed_record in parsed_dataset.take(10):
  print(repr(parsed_record))
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.075646>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=3>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=1.5927771>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.36066774>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.6698696>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.98963356>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-2.2671726>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=3>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.1427735>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.4882174>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=1.390403>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.44764492>}

Ici, la fonction tf.parse_example tf.train.Example champs tf.train.Example en tenseurs standard.

Fichiers TFRecord en Python

Le module tf.io contient également des fonctions pur-Python pour la lecture et l'écriture de fichiers TFRecord.

Ecrire un fichier TFRecord

Ensuite, écrivez les 10 000 observations dans le fichier test.tfrecord . Chaque observation est convertie en un message tf.train.Example , puis écrite dans un fichier. Vous pouvez ensuite vérifier que le fichier test.tfrecord a bien été créé:

# Write the `tf.train.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
  for i in range(n_observations):
    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
    writer.write(example)
du -sh {filename}
984K    test.tfrecord

Lire un fichier TFRecord

Ces tenseurs sérialisés peuvent être facilement analysés à l'aide de tf.train.Example.ParseFromString :

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>
for raw_record in raw_dataset.take(1):
  example = tf.train.Example()
  example.ParseFromString(raw_record.numpy())
  print(example)
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "cat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: -0.07564599812030792
      }
    }
  }
}


Procédure pas à pas: lecture et écriture de données d'image

Ceci est un exemple de bout en bout de la façon de lire et d'écrire des données d'image à l'aide de TFRecords. En utilisant une image comme données d'entrée, vous allez écrire les données sous forme de fichier TFRecord, puis relire le fichier et afficher l'image.

Cela peut être utile si, par exemple, vous souhaitez utiliser plusieurs modèles sur le même jeu de données d'entrée. Au lieu de stocker les données d'image brutes, elles peuvent être prétraitées au format TFRecords, et peuvent être utilisées dans tous les traitements et modélisations ultérieurs.

Commençons par télécharger cette image d'un chat dans la neige et cette photo du pont de Williamsburg à New York en construction.

Récupérez les images

cat_in_snow  = tf.keras.utils.get_file('320px-Felis_catus-cat_on_snow.jpg', 'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file('194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg','https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg
24576/17858 [=========================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg
16384/15477 [===============================] - 0s 0us/step

display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))

jpeg

display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))

jpeg

Ecrire le fichier TFRecord

Comme précédemment, encodez les fonctionnalités en tant que types compatibles avec tf.train.Example . Cela stocke la fonction de chaîne d'image brute, ainsi que la fonction de hauteur, largeur, profondeur et label arbitraire. Ce dernier est utilisé lorsque vous écrivez le fichier pour faire la distinction entre l'image du chat et l'image du pont. Utilisez 0 pour l'image du chat et 1 pour l'image du pont:

image_labels = {
    cat_in_snow : 0,
    williamsburg_bridge : 1,
}
# This is an example, just using the cat image.
image_string = open(cat_in_snow, 'rb').read()

label = image_labels[cat_in_snow]

# Create a dictionary with features that may be relevant.
def image_example(image_string, label):
  image_shape = tf.image.decode_jpeg(image_string).shape

  feature = {
      'height': _int64_feature(image_shape[0]),
      'width': _int64_feature(image_shape[1]),
      'depth': _int64_feature(image_shape[2]),
      'label': _int64_feature(label),
      'image_raw': _bytes_feature(image_string),
  }

  return tf.train.Example(features=tf.train.Features(feature=feature))

for line in str(image_example(image_string, label)).split('\n')[:15]:
  print(line)
print('...')
features {
  feature {
    key: "depth"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "height"
    value {
      int64_list {
        value: 213
      }
...

Notez que toutes les fonctionnalités sont désormais stockées dans le message tf.train.Example . Ensuite, fonctionnalisez le code ci-dessus et écrivez les exemples de messages dans un fichier nommé images.tfrecords :

# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.train.Example` messages.
# Then, write to a `.tfrecords` file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for filename, label in image_labels.items():
    image_string = open(filename, 'rb').read()
    tf_example = image_example(image_string, label)
    writer.write(tf_example.SerializeToString())
du -sh {record_file}
36K images.tfrecords

Lire le fichier TFRecord

Vous avez maintenant le images.tfrecords - images.tfrecords - et pouvez maintenant parcourir les enregistrements qu'il contient pour relire ce que vous avez écrit. Étant donné que dans cet exemple, vous ne reproduisez que l'image, la seule fonctionnalité dont vous aurez besoin est la chaîne d'image brute. Extrayez-le en utilisant les getters décrits ci-dessus, à savoir example.features.feature['image_raw'].bytes_list.value[0] . Vous pouvez également utiliser les étiquettes pour déterminer quel enregistrement est le chat et lequel est le pont:

raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

# Create a dictionary describing the features.
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  # Parse the input tf.train.Example proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
<MapDataset shapes: {depth: (), height: (), image_raw: (), label: (), width: ()}, types: {depth: tf.int64, height: tf.int64, image_raw: tf.string, label: tf.int64, width: tf.int64}>

Récupérez les images du fichier TFRecord:

for image_features in parsed_image_dataset:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))

jpeg

jpeg