TFRecord y tf.train.Example

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar libreta

El formato TFRecord es un formato simple para almacenar una secuencia de registros binarios.

Los búferes de protocolo son una biblioteca multiplataforma y multilingüe para la serialización eficiente de datos estructurados.

Los mensajes de protocolo están definidos por archivos .proto , estos son a menudo la forma más fácil de entender un tipo de mensaje.

El mensaje tf.train.Example (o protobuf) es un tipo de mensaje flexible que representa una asignación {"string": value} . Está diseñado para usarse con TensorFlow y se usa en todas las API de nivel superior, como TFX .

Este cuaderno demuestra cómo crear, analizar y usar el mensaje tf.train.Example y luego serializar, escribir y leer mensajes tf.train.Example hacia y desde archivos .tfrecord .

Configuración

import tensorflow as tf

import numpy as np
import IPython.display as display

tf.train.Example

Tipos de datos para tf.train.Example

Fundamentalmente, un tf.train.Example es un mapeo {"string": tf.train.Feature} .

El tipo de mensaje tf.train.Feature puede aceptar uno de los tres tipos siguientes (consulte el archivo .proto como referencia). La mayoría de los otros tipos genéricos se pueden obligar a uno de estos:

  1. tf.train.BytesList (los siguientes tipos pueden ser forzados)

    • string
    • byte
  2. tf.train.FloatList (los siguientes tipos pueden ser forzados)

    • float ( float32 )
    • double ( float64 )
  3. tf.train.Int64List (los siguientes tipos pueden ser forzados)

    • bool
    • enum
    • int32
    • uint32
    • int64
    • uint64

Para convertir un tipo TensorFlow estándar en un tf.train.Example compatible con tf.train.Feature , puede usar las funciones de acceso directo a continuación. Tenga en cuenta que cada función toma un valor de entrada escalar y devuelve un tf.train.Feature que contiene uno de los tres tipos de list anteriores:

# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

A continuación se muestran algunos ejemplos de cómo funcionan estas funciones. Tenga en cuenta los diferentes tipos de entrada y los tipos de salida estandarizados. Si el tipo de entrada para una función no coincide con uno de los tipos coercibles indicados anteriormente, la función generará una excepción (por ejemplo _int64_feature(1.0) generará un error porque 1.0 es un flotante; por lo tanto, debe usarse con la función _float_feature en su lugar ):

print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))

print(_float_feature(np.exp(1)))

print(_int64_feature(True))
print(_int64_feature(1))
bytes_list {
  value: "test_string"
}

bytes_list {
  value: "test_bytes"
}

float_list {
  value: 2.7182817459106445
}

int64_list {
  value: 1
}

int64_list {
  value: 1
}

Todos los mensajes prototipo se pueden serializar en una cadena binaria utilizando el método .SerializeToString :

feature = _float_feature(np.exp(1))

feature.SerializeToString()
b'\x12\x06\n\x04T\xf8-@'

Creación de un mensaje tf.train.Example

Suponga que desea crear un mensaje tf.train.Example a partir de datos existentes. En la práctica, el conjunto de datos puede provenir de cualquier lugar, pero el procedimiento para crear el mensaje tf.train.Example a partir de una sola observación será el mismo:

  1. Dentro de cada observación, cada valor debe convertirse en un tf.train.Feature que contenga uno de los 3 tipos compatibles, utilizando una de las funciones anteriores.

  2. Usted crea un mapa (diccionario) desde la cadena del nombre de la función hasta el valor de la función codificado producido en el n.º 1.

  3. El mapa generado en el paso 2 se convierte en un mensaje de Features .

En este cuaderno, creará un conjunto de datos utilizando NumPy.

Este conjunto de datos tendrá 4 características:

  • una característica booleana, False o True con la misma probabilidad
  • una característica entera uniformemente elegida al azar de [0, 5]
  • una característica de cadena generada a partir de una tabla de cadenas mediante el uso de la característica de entero como índice
  • una característica flotante de una distribución normal estándar

Considere una muestra que consta de 10,000 observaciones distribuidas de manera independiente e idéntica de cada una de las distribuciones anteriores:

# The number of observations in the dataset.
n_observations = int(1e4)

# Boolean feature, encoded as False or True.
feature0 = np.random.choice([False, True], n_observations)

# Integer feature, random from 0 to 4.
feature1 = np.random.randint(0, 5, n_observations)

# String feature.
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# Float feature, from a standard normal distribution.
feature3 = np.random.randn(n_observations)

Cada una de estas características se puede convertir en un tipo compatible con tf.train.Example usando uno de _bytes_feature , _float_feature , _int64_feature . A continuación, puede crear un mensaje tf.train.Example a partir de estas funciones codificadas:

def serialize_example(feature0, feature1, feature2, feature3):
  """
  Creates a tf.train.Example message ready to be written to a file.
  """
  # Create a dictionary mapping the feature name to the tf.train.Example-compatible
  # data type.
  feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
  }

  # Create a Features message using tf.train.Example.

  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

Por ejemplo, suponga que tiene una sola observación del conjunto de datos, [False, 4, bytes('goat'), 0.9876] . Puede crear e imprimir el mensaje tf.train.Example para esta observación usando create_message() . Cada observación individual se escribirá como un mensaje de Features según lo anterior. Tenga en cuenta que el mensaje tf.train.Example es solo un envoltorio alrededor del mensaje Features :

# This is an example observation from the dataset.

example_observation = []

serialized_example = serialize_example(False, 4, b'goat', 0.9876)
serialized_example
b'\nR\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?'

Para decodificar el mensaje, utilice el método tf.train.Example.FromString .

example_proto = tf.train.Example.FromString(serialized_example)
example_proto
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.9876000285148621
      }
    }
  }
}

Detalles del formato de TFRecords

Un archivo TFRecord contiene una secuencia de registros. El archivo solo se puede leer secuencialmente.

Cada registro contiene una cadena de bytes, para la carga útil de datos, más la longitud de datos y hashes CRC-32C (CRC de 32 bits que utiliza el polinomio de Castagnoli ) para la verificación de integridad.

Cada registro se almacena en los siguientes formatos:

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

Los registros se concatenan para producir el archivo. Los CRC se describen aquí , y la máscara de un CRC es:

masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul

Archivos TFRecord usando tf.data

El módulo tf.data también proporciona herramientas para leer y escribir datos en TensorFlow.

Escribir un archivo TFRecord

La forma más fácil de obtener los datos en un conjunto de datos es usar el método from_tensor_slices .

Aplicado a una matriz, devuelve un conjunto de datos de escalares:

tf.data.Dataset.from_tensor_slices(feature1)
<TensorSliceDataset element_spec=TensorSpec(shape=(), dtype=tf.int64, name=None)>

Aplicado a una tupla de matrices, devuelve un conjunto de datos de tuplas:

features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))
features_dataset
<TensorSliceDataset element_spec=(TensorSpec(shape=(), dtype=tf.bool, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None), TensorSpec(shape=(), dtype=tf.string, name=None), TensorSpec(shape=(), dtype=tf.float64, name=None))>
# Use `take(1)` to only pull one example from the dataset.
for f0,f1,f2,f3 in features_dataset.take(1):
  print(f0)
  print(f1)
  print(f2)
  print(f3)
tf.Tensor(False, shape=(), dtype=bool)
tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(b'goat', shape=(), dtype=string)
tf.Tensor(0.5251196235602504, shape=(), dtype=float64)

Utilice el método tf.data.Dataset.map para aplicar una función a cada elemento de un conjunto de Dataset .

La función mapeada debe operar en el modo gráfico de TensorFlow; debe operar y devolver tf.Tensors . Una función que no es de tensor, como serialize_example , se puede encapsular con tf.py_function para que sea compatible.

El uso de tf.py_function requiere especificar la forma y el tipo de información que de otro modo no estaría disponible:

def tf_serialize_example(f0,f1,f2,f3):
  tf_string = tf.py_function(
    serialize_example,
    (f0, f1, f2, f3),  # Pass these args to the above function.
    tf.string)      # The return type is `tf.string`.
  return tf.reshape(tf_string, ()) # The result is a scalar.
tf_serialize_example(f0, f1, f2, f3)
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04=n\x06?\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04'>

Aplique esta función a cada elemento del conjunto de datos:

serialized_features_dataset = features_dataset.map(tf_serialize_example)
serialized_features_dataset
<MapDataset element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>
def generator():
  for features in features_dataset:
    yield serialize_example(*features)
serialized_features_dataset = tf.data.Dataset.from_generator(
    generator, output_types=tf.string, output_shapes=())
serialized_features_dataset
<FlatMapDataset element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

Y escríbalos en un archivo TFRecord:

filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)
WARNING:tensorflow:From /tmp/ipykernel_25215/3575438268.py:2: TFRecordWriter.__init__ (from tensorflow.python.data.experimental.ops.writers) is deprecated and will be removed in a future version.
Instructions for updating:
To write TFRecords to disk, use `tf.io.TFRecordWriter`. To save and load the contents of a dataset, use `tf.data.experimental.save` and `tf.data.experimental.load`

Lectura de un archivo TFRecord

También puede leer el archivo TFRecord usando la clase tf.data.TFRecordDataset .

Puede encontrar más información sobre cómo consumir archivos TFRecord usando tf.data en la guía tf.data: Build TensorFlow input pipelines .

El uso TFRecordDataset s puede ser útil para estandarizar los datos de entrada y optimizar el rendimiento.

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

En este punto, el conjunto de datos contiene mensajes tf.train.Example serializados. Cuando se repite, los devuelve como tensores de cadena escalares.

Use el método .take para mostrar solo los primeros 10 registros.

for raw_record in raw_dataset.take(10):
  print(repr(raw_record))
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04=n\x06?'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x9d\xfa\x98\xbe\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03dog\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04a\xc0r?\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x92Q(?'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04>\xc0\xe5>\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04I!\xde\xbe\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xe0\x1a\xab\xbf\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x87\xb2\xd7?\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04n\xe19>\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x1as\xd9\xbf\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'>

Estos tensores se pueden analizar usando la siguiente función. Tenga en cuenta que la tf.data.Dataset feature_description usa la ejecución de gráficos y necesita esta descripción para construir su forma y tipo de firma:

# Create a description of the features.
feature_description = {
    'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}

def _parse_function(example_proto):
  # Parse the input `tf.train.Example` proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, feature_description)

Alternativamente, use tf.parse example para analizar todo el lote a la vez. Aplique esta función a cada elemento del conjunto de datos utilizando el método tf.data.Dataset.map :

parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset
<MapDataset element_spec={'feature0': TensorSpec(shape=(), dtype=tf.int64, name=None), 'feature1': TensorSpec(shape=(), dtype=tf.int64, name=None), 'feature2': TensorSpec(shape=(), dtype=tf.string, name=None), 'feature3': TensorSpec(shape=(), dtype=tf.float32, name=None)}>

Utilice la ejecución entusiasta para mostrar las observaciones en el conjunto de datos. Hay 10 000 observaciones en este conjunto de datos, pero solo mostrará las primeras 10. Los datos se muestran como un diccionario de características. Cada elemento es un tf.Tensor y el elemento numpy de este tensor muestra el valor de la función:

for parsed_record in parsed_dataset.take(10):
  print(repr(parsed_record))
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.5251196>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.29878703>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'dog'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.94824797>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.65749466>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.44873232>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.4338477>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-1.3367577>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=1.6851357>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.18152401>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-1.6988251>}

Aquí, la función tf.parse_example desempaqueta los campos tf.train.Example en tensores estándar.

Archivos TFRecord en Python

El módulo tf.io también contiene funciones de Python puro para leer y escribir archivos TFRecord.

Escribir un archivo TFRecord

A continuación, escriba las 10 000 observaciones en el archivo test.tfrecord . Cada observación se convierte en un mensaje tf.train.Example y luego se escribe en un archivo. A continuación, puede verificar que se ha creado el archivo test.tfrecord :

# Write the `tf.train.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
  for i in range(n_observations):
    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
    writer.write(example)
du -sh {filename}
984K    test.tfrecord

Lectura de un archivo TFRecord

Estos tensores serializados se pueden analizar fácilmente usando tf.train.Example.ParseFromString :

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>
for raw_record in raw_dataset.take(1):
  example = tf.train.Example()
  example.ParseFromString(raw_record.numpy())
  print(example)
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.5251196026802063
      }
    }
  }
}

Eso devuelve un proto tf.train.Example que es difícil de usar tal cual, pero es fundamentalmente una representación de un:

Dict[str,
     Union[List[float],
           List[int],
           List[str]]]

El siguiente código convierte manualmente el Example en un diccionario de matrices NumPy, sin usar TensorFlow Ops. Consulte el archivo PROTO para obtener detalles.

result = {}
# example.features.feature is the dictionary
for key, feature in example.features.feature.items():
  # The values are the Feature objects which contain a `kind` which contains:
  # one of three fields: bytes_list, float_list, int64_list

  kind = feature.WhichOneof('kind')
  result[key] = np.array(getattr(feature, kind).value)

result
{'feature3': array([0.5251196]),
 'feature1': array([4]),
 'feature0': array([0]),
 'feature2': array([b'goat'], dtype='|S4')}

Tutorial: lectura y escritura de datos de imagen

Este es un ejemplo completo de cómo leer y escribir datos de imágenes usando TFRecords. Usando una imagen como datos de entrada, escribirá los datos como un archivo TFRecord, luego volverá a leer el archivo y mostrará la imagen.

Esto puede ser útil si, por ejemplo, desea utilizar varios modelos en el mismo conjunto de datos de entrada. En lugar de almacenar los datos de la imagen sin procesar, se pueden preprocesar en el formato TFRecords, y eso se puede usar en todo el procesamiento y modelado posteriores.

Primero, descarguemos esta imagen de un gato en la nieve y esta foto del Puente Williamsburg, NYC en construcción.

Obtener las imágenes

cat_in_snow  = tf.keras.utils.get_file(
    '320px-Felis_catus-cat_on_snow.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')

williamsburg_bridge = tf.keras.utils.get_file(
    '194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg
24576/17858 [=========================================] - 0s 0us/step
32768/17858 [=======================================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg
16384/15477 [===============================] - 0s 0us/step
24576/15477 [===============================================] - 0s 0us/step
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))

jpeg

display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))

jpeg

Escribir el archivo TFRecord

Como antes, codifique las características como tipos compatibles con tf.train.Example . Esto almacena la función de cadena de imagen sin procesar, así como la función de altura, ancho, profundidad y label arbitraria. Este último se usa cuando escribe el archivo para distinguir entre la imagen del gato y la imagen del puente. Use 0 para la imagen del gato y 1 para la imagen del puente:

image_labels = {
    cat_in_snow : 0,
    williamsburg_bridge : 1,
}
# This is an example, just using the cat image.
image_string = open(cat_in_snow, 'rb').read()

label = image_labels[cat_in_snow]

# Create a dictionary with features that may be relevant.
def image_example(image_string, label):
  image_shape = tf.io.decode_jpeg(image_string).shape

  feature = {
      'height': _int64_feature(image_shape[0]),
      'width': _int64_feature(image_shape[1]),
      'depth': _int64_feature(image_shape[2]),
      'label': _int64_feature(label),
      'image_raw': _bytes_feature(image_string),
  }

  return tf.train.Example(features=tf.train.Features(feature=feature))

for line in str(image_example(image_string, label)).split('\n')[:15]:
  print(line)
print('...')
features {
  feature {
    key: "depth"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "height"
    value {
      int64_list {
        value: 213
      }
...

Observe que todas las características ahora están almacenadas en el mensaje tf.train.Example . A continuación, funcionalice el código anterior y escriba los mensajes de ejemplo en un archivo llamado images.tfrecords :

# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.train.Example` messages.
# Then, write to a `.tfrecords` file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for filename, label in image_labels.items():
    image_string = open(filename, 'rb').read()
    tf_example = image_example(image_string, label)
    writer.write(tf_example.SerializeToString())
du -sh {record_file}
36K images.tfrecords

Leer el archivo TFRecord

Ahora tiene el archivo, images.tfrecords , y ahora puede iterar sobre los registros que contiene para volver a leer lo que escribió. Dado que en este ejemplo solo reproducirá la imagen, la única característica que necesitará es la cadena de imagen sin formato. Extráigalo utilizando los captadores descritos anteriormente, a saber, example.features.feature['image_raw'].bytes_list.value[0] . También puede usar las etiquetas para determinar qué registro es el gato y cuál es el puente:

raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

# Create a dictionary describing the features.
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  # Parse the input tf.train.Example proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
<MapDataset element_spec={'depth': TensorSpec(shape=(), dtype=tf.int64, name=None), 'height': TensorSpec(shape=(), dtype=tf.int64, name=None), 'image_raw': TensorSpec(shape=(), dtype=tf.string, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None), 'width': TensorSpec(shape=(), dtype=tf.int64, name=None)}>

Recupera las imágenes del archivo TFRecord:

for image_features in parsed_image_dataset:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))

jpeg

jpeg