TFRecord e tf.train.Example

Ver no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

O formato TFRecord é um formato simples para armazenar uma sequência de registros binários.

Os buffers de protocolo são uma biblioteca de plataforma cruzada e linguagem cruzada para serialização eficiente de dados estruturados.

As mensagens de protocolo são definidas por arquivos .proto , geralmente a maneira mais fácil de entender um tipo de mensagem.

A mensagem tf.train.Example (ou protobuf) é um tipo de mensagem flexível que representa um mapeamento {"string": value} . Ele foi projetado para uso com o TensorFlow e em todas as APIs de nível superior, como TFX .

Este bloco de notas demonstra como criar, analisar e usar a mensagem tf.train.Example e, em seguida, serializar, gravar e ler mensagens tf.train.Example de e para arquivos .tfrecord .

Configurar

import tensorflow as tf

import numpy as np
import IPython.display as display

tf.train.Example

Tipos de dados para tf.train.Example

Fundamentalmente, um tf.train.Example é um mapeamento {"string": tf.train.Feature} .

O tipo de mensagem tf.train.Feature pode aceitar um dos três tipos a seguir (consulte o arquivo .proto para referência). A maioria dos outros tipos genéricos pode ser forçada a um destes:

  1. tf.train.BytesList (os seguintes tipos podem ser forçados)

    • string
    • byte
  2. tf.train.FloatList (os seguintes tipos podem ser forçados)

    • float ( float32 )
    • double ( float64 )
  3. tf.train.Int64List (os seguintes tipos podem ser forçados)

    • bool
    • enum
    • int32
    • uint32
    • int64
    • uint64

Para converter um tipo padrão do TensorFlow em um tf.train.Example compatível com tf.train.Feature , você pode usar as funções de atalho abaixo. Observe que cada função recebe um valor de entrada escalar e retorna um tf.train.Feature contendo um dos três tipos de list acima:

# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  if isinstance(value, type(tf.constant(0))):
    value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

Abaixo estão alguns exemplos de como essas funções funcionam. Observe os diversos tipos de entrada e os tipos de saída padronizados. Se o tipo de entrada para uma função não corresponder a um dos tipos coercíveis declarados acima, a função gerará uma exceção (por exemplo, _int64_feature(1.0) apresentará um erro porque 1.0 é um float - portanto, deve ser usado com a função _float_feature ):

print(_bytes_feature(b'test_string'))
print(_bytes_feature(u'test_bytes'.encode('utf-8')))

print(_float_feature(np.exp(1)))

print(_int64_feature(True))
print(_int64_feature(1))
bytes_list {
  value: "test_string"
}

bytes_list {
  value: "test_bytes"
}

float_list {
  value: 2.7182817459106445
}

int64_list {
  value: 1
}

int64_list {
  value: 1
}

Todas as mensagens proto podem ser serializadas em uma string binária usando o método .SerializeToString :

feature = _float_feature(np.exp(1))

feature.SerializeToString()
b'\x12\x06\n\x04T\xf8-@'

Criação de uma mensagem tf.train.Example

Suponha que você queira criar uma mensagem tf.train.Example partir de dados existentes. Na prática, o conjunto de dados pode vir de qualquer lugar, mas o procedimento de criação da mensagem tf.train.Example partir de uma única observação será o mesmo:

  1. Dentro de cada observação, cada valor precisa ser convertido em um tf.train.Feature contendo um dos 3 tipos compatíveis, usando uma das funções acima.

  2. Você cria um mapa (dicionário) da string do nome do recurso para o valor do recurso codificado produzido em # 1.

  3. O mapa produzido na etapa 2 é convertido em uma mensagem de Features .

Neste bloco de notas, você criará um conjunto de dados usando NumPy.

Este conjunto de dados terá 4 recursos:

  • um recurso booleano, False ou True com igual probabilidade
  • um recurso de número inteiro escolhido de maneira aleatória e uniforme de [0, 5]
  • um recurso de string gerado a partir de uma tabela de string usando o recurso inteiro como um índice
  • um recurso flutuante de uma distribuição normal padrão

Considere uma amostra que consiste em 10.000 observações distribuídas de forma independente e idêntica de cada uma das distribuições acima:

# The number of observations in the dataset.
n_observations = int(1e4)

# Boolean feature, encoded as False or True.
feature0 = np.random.choice([False, True], n_observations)

# Integer feature, random from 0 to 4.
feature1 = np.random.randint(0, 5, n_observations)

# String feature.
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# Float feature, from a standard normal distribution.
feature3 = np.random.randn(n_observations)

Cada um desses recursos pode ser forçado a um tipo compatível com tf.train.Example usando um dos _bytes_feature , _float_feature , _int64_feature . Você pode então criar uma mensagem tf.train.Example partir destes recursos codificados:

def serialize_example(feature0, feature1, feature2, feature3):
  """
  Creates a tf.train.Example message ready to be written to a file.
  """
  # Create a dictionary mapping the feature name to the tf.train.Example-compatible
  # data type.
  feature = {
      'feature0': _int64_feature(feature0),
      'feature1': _int64_feature(feature1),
      'feature2': _bytes_feature(feature2),
      'feature3': _float_feature(feature3),
  }

  # Create a Features message using tf.train.Example.

  example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
  return example_proto.SerializeToString()

Por exemplo, suponha que você tenha uma única observação do conjunto de dados, [False, 4, bytes('goat'), 0.9876] . Você pode criar e imprimir a mensagem tf.train.Example para esta observação usando create_message() . Cada observação será escrita como uma mensagem de Features acordo com o acima. Observe que a mensagem tf.train.Example é apenas um wrapper em torno da mensagem Features :

# This is an example observation from the dataset.

example_observation = []

serialized_example = serialize_example(False, 4, b'goat', 0.9876)
serialized_example
b'\nR\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04[\xd3|?\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat'

Para decodificar a mensagem, use o método tf.train.Example.FromString .

example_proto = tf.train.Example.FromString(serialized_example)
example_proto
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 4
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "goat"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.9876000285148621
      }
    }
  }
}

Detalhes do formato TFRecords

Um arquivo TFRecord contém uma sequência de registros. O arquivo só pode ser lido sequencialmente.

Cada registro contém uma string de bytes, para a carga útil de dados, mais o comprimento dos dados e hashes CRC-32C (CRC de 32 bits usando o polinômio Castagnoli ) para verificação de integridade.

Cada registro é armazenado nos seguintes formatos:

uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data

Os registros são concatenados para produzir o arquivo. Os CRCs são descritos aqui , e a máscara de um CRC é:

masked_crc = ((crc >> 15) | (crc << 17)) + 0xa282ead8ul

Arquivos tf.data usando tf.data

O módulo tf.data também fornece ferramentas para ler e gravar dados no TensorFlow.

Gravando um arquivo TFRecord

A maneira mais fácil de inserir os dados em um conjunto de dados é usar o método from_tensor_slices .

Aplicado a uma matriz, ele retorna um conjunto de dados de escalares:

tf.data.Dataset.from_tensor_slices(feature1)
<TensorSliceDataset shapes: (), types: tf.int64>

Aplicado a uma tupla de matrizes, ele retorna um conjunto de dados de tuplas:

features_dataset = tf.data.Dataset.from_tensor_slices((feature0, feature1, feature2, feature3))
features_dataset
<TensorSliceDataset shapes: ((), (), (), ()), types: (tf.bool, tf.int64, tf.string, tf.float64)>
# Use `take(1)` to only pull one example from the dataset.
for f0,f1,f2,f3 in features_dataset.take(1):
  print(f0)
  print(f1)
  print(f2)
  print(f3)
tf.Tensor(False, shape=(), dtype=bool)
tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(b'horse', shape=(), dtype=string)
tf.Tensor(0.3707167206984876, shape=(), dtype=float64)

Use o método tf.data.Dataset.map para aplicar uma função a cada elemento de um Dataset .

A função mapeada deve operar no modo de gráfico TensorFlow - deve operar e retornar tf.Tensors . Uma função não tensor, como serialize_example , pode ser tf.py_function com tf.py_function para torná-la compatível.

O uso de tf.py_function requer a especificação da forma e do tipo de informação que, de outra forma, não está disponível:

def tf_serialize_example(f0,f1,f2,f3):
  tf_string = tf.py_function(
    serialize_example,
    (f0, f1, f2, f3),  # Pass these args to the above function.
    tf.string)      # The return type is `tf.string`.
  return tf.reshape(tf_string, ()) # The result is a scalar.
tf_serialize_example(f0, f1, f2, f3)
<tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x95\xce\xbd>\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03'>

Aplique esta função a cada elemento no conjunto de dados:

serialized_features_dataset = features_dataset.map(tf_serialize_example)
serialized_features_dataset
<MapDataset shapes: (), types: tf.string>
def generator():
  for features in features_dataset:
    yield serialize_example(*features)
serialized_features_dataset = tf.data.Dataset.from_generator(
    generator, output_types=tf.string, output_shapes=())
serialized_features_dataset
<FlatMapDataset shapes: (), types: tf.string>

E grave-os em um arquivo TFRecord:

filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(serialized_features_dataset)

Lendo um arquivo TFRecord

Você também pode ler o arquivo TFRecord usando a classe tf.data.TFRecordDataset .

Mais informações sobre como consumir arquivos TFRecord usando tf.data podem ser encontradas em tf.data: Guia de pipelines de entrada do TensorFlow de construção .

Usar TFRecordDataset s pode ser útil para padronizar dados de entrada e otimizar o desempenho.

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>

Neste ponto, o conjunto de dados contém mensagens tf.train.Example serializadas. Quando iterado, ele os retorna como tensores de string escalar.

Use o método .take para mostrar apenas os primeiros 10 registros.

for raw_record in raw_dataset.take(10):
  print(repr(raw_record))
<tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x95\xce\xbd>\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04iX\x9a\xbe\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xb6\xa2\xb5\xbf\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04>!\x84\xbc'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nR\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xa9\xdcE?'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x98\x8bb=\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00'>
<tf.Tensor: shape=(), dtype=string, numpy=b"\nR\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\xb6\xe2'\xbf\n\x14\n\x08feature2\x12\x08\n\x06\n\x04goat\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x01\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x04">
<tf.Tensor: shape=(), dtype=string, numpy=b'\nS\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x15\n\x08feature2\x12\t\n\x07\n\x05horse\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04\x9c\xc4I>\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x03'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nQ\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04p\xd3\xbd\xbc\n\x13\n\x08feature2\x12\x07\n\x05\n\x03cat\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x00'>
<tf.Tensor: shape=(), dtype=string, numpy=b'\nU\n\x14\n\x08feature3\x12\x08\x12\x06\n\x04`\x8bp?\n\x11\n\x08feature1\x12\x05\x1a\x03\n\x01\x02\n\x11\n\x08feature0\x12\x05\x1a\x03\n\x01\x00\n\x17\n\x08feature2\x12\x0b\n\t\n\x07chicken'>

Esses tensores podem ser analisados ​​usando a função abaixo. Observe que feature_description é necessária aqui porquetf.data.Dataset s usa a execução de gráfico e precisa desta descrição para construir sua forma e assinatura de tipo:

# Create a description of the features.
feature_description = {
    'feature0': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature1': tf.io.FixedLenFeature([], tf.int64, default_value=0),
    'feature2': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'feature3': tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
}

def _parse_function(example_proto):
  # Parse the input `tf.train.Example` proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, feature_description)

Como alternativa, use o tf.parse example para analisar todo o lote de uma vez. Aplique esta função a cada item no conjunto de dados usando o método tf.data.Dataset.map :

parsed_dataset = raw_dataset.map(_parse_function)
parsed_dataset
<MapDataset shapes: {feature0: (), feature1: (), feature2: (), feature3: ()}, types: {feature0: tf.int64, feature1: tf.int64, feature2: tf.string, feature3: tf.float32}>

Use a execução rápida para exibir as observações no conjunto de dados. Existem 10.000 observações neste conjunto de dados, mas você só exibirá as primeiras 10. Os dados são exibidos como um dicionário de recursos. Cada item é um tf.Tensor , e o elemento numpy desse tensor exibe o valor do recurso:

for parsed_record in parsed_dataset.take(10):
  print(repr(parsed_record))
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=3>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.37071672>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=3>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.30145577>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=3>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-1.419028>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.016129132>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.77289826>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.05530891>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=4>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'goat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.6558031>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=3>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'horse'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.19703907>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'cat'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=-0.02317211>}
{'feature0': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'feature1': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'feature2': <tf.Tensor: shape=(), dtype=string, numpy=b'chicken'>, 'feature3': <tf.Tensor: shape=(), dtype=float32, numpy=0.9396267>}

Aqui, a função tf.parse_example descompacta os campos tf.train.Example em tensores padrão.

Arquivos TFRecord em Python

O módulo tf.io também contém funções Python puro para leitura e gravação de arquivos TFRecord.

Gravando um arquivo TFRecord

Em seguida, escreva as 10.000 observações no arquivo test.tfrecord . Cada observação é convertida em uma mensagem tf.train.Example , em seguida, gravada em um arquivo. Você pode então verificar se o arquivo test.tfrecord foi criado:

# Write the `tf.train.Example` observations to the file.
with tf.io.TFRecordWriter(filename) as writer:
  for i in range(n_observations):
    example = serialize_example(feature0[i], feature1[i], feature2[i], feature3[i])
    writer.write(example)
du -sh {filename}
984K    test.tfrecord

Lendo um arquivo TFRecord

Esses tensores serializados podem ser facilmente analisados ​​usando tf.train.Example.ParseFromString :

filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)
raw_dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>
for raw_record in raw_dataset.take(1):
  example = tf.train.Example()
  example.ParseFromString(raw_record.numpy())
  print(example)
features {
  feature {
    key: "feature0"
    value {
      int64_list {
        value: 0
      }
    }
  }
  feature {
    key: "feature1"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "feature2"
    value {
      bytes_list {
        value: "horse"
      }
    }
  }
  feature {
    key: "feature3"
    value {
      float_list {
        value: 0.37071672081947327
      }
    }
  }
}

Passo a passo: Leitura e gravação de dados de imagem

Este é um exemplo completo de como ler e gravar dados de imagem usando TFRecords. Usando uma imagem como dados de entrada, você gravará os dados como um arquivo TFRecord e, em seguida, lerá o arquivo de volta e exibirá a imagem.

Isso pode ser útil se, por exemplo, você quiser usar vários modelos no mesmo conjunto de dados de entrada. Em vez de armazenar os dados brutos da imagem, eles podem ser pré-processados ​​no formato TFRecords e podem ser usados ​​em todos os processamentos e modelagens posteriores.

Primeiro, vamos baixar esta imagem de um gato na neve e esta foto da ponte Williamsburg, NYC em construção.

Pegue as imagens

cat_in_snow  = tf.keras.utils.get_file(
    '320px-Felis_catus-cat_on_snow.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')

williamsburg_bridge = tf.keras.utils.get_file(
    '194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg
24576/17858 [=========================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg
16384/15477 [===============================] - 0s 0us/step
display.display(display.Image(filename=cat_in_snow))
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))

JPEG

display.display(display.Image(filename=williamsburg_bridge))
display.display(display.HTML('<a "href=https://commons.wikimedia.org/wiki/File:New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg">From Wikimedia</a>'))

JPEG

Grave o arquivo TFRecord

Como antes, codifique os recursos como tipos compatíveis com tf.train.Example . Isso armazena o recurso de string de imagem bruta, bem como a altura, largura, profundidade e recurso de label arbitrário. O último é usado quando você grava o arquivo para distinguir entre a imagem do gato e a imagem da ponte. Use 0 para a imagem do gato e 1 para a imagem da ponte:

image_labels = {
    cat_in_snow : 0,
    williamsburg_bridge : 1,
}
# This is an example, just using the cat image.
image_string = open(cat_in_snow, 'rb').read()

label = image_labels[cat_in_snow]

# Create a dictionary with features that may be relevant.
def image_example(image_string, label):
  image_shape = tf.io.decode_jpeg(image_string).shape

  feature = {
      'height': _int64_feature(image_shape[0]),
      'width': _int64_feature(image_shape[1]),
      'depth': _int64_feature(image_shape[2]),
      'label': _int64_feature(label),
      'image_raw': _bytes_feature(image_string),
  }

  return tf.train.Example(features=tf.train.Features(feature=feature))

for line in str(image_example(image_string, label)).split('\n')[:15]:
  print(line)
print('...')
features {
  feature {
    key: "depth"
    value {
      int64_list {
        value: 3
      }
    }
  }
  feature {
    key: "height"
    value {
      int64_list {
        value: 213
      }
...

Observe que todos os recursos agora estão armazenados na mensagem tf.train.Example . Em seguida, funcionalize o código acima e grave as mensagens de exemplo em um arquivo denominado images.tfrecords :

# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.train.Example` messages.
# Then, write to a `.tfrecords` file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for filename, label in image_labels.items():
    image_string = open(filename, 'rb').read()
    tf_example = image_example(image_string, label)
    writer.write(tf_example.SerializeToString())
du -sh {record_file}
36K images.tfrecords

Leia o arquivo TFRecord

Agora você tem o arquivo - images.tfrecords - e agora pode iterar sobre os registros nele para ler de volta o que você escreveu. Dado que, neste exemplo, você apenas reproduzirá a imagem, o único recurso de que precisará é a string de imagem bruta. Extraia-o usando os getters descritos acima, a saber, example.features.feature['image_raw'].bytes_list.value[0] . Você também pode usar os rótulos para determinar qual registro é o gato e qual é a ponte:

raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

# Create a dictionary describing the features.
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  # Parse the input tf.train.Example proto using the dictionary above.
  return tf.io.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
parsed_image_dataset
<MapDataset shapes: {depth: (), height: (), image_raw: (), label: (), width: ()}, types: {depth: tf.int64, height: tf.int64, image_raw: tf.string, label: tf.int64, width: tf.int64}>

Recupere as imagens do arquivo TFRecord:

for image_features in parsed_image_dataset:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))

JPEG

JPEG