หน้านี้ได้รับการแปลโดย Cloud Translation API
Switch to English

tf.data: สร้างไปป์ไลน์อินพุต TensorFlow

ดูใน TensorFlow.org เรียกใช้ใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดสมุดบันทึก

tf.data API ช่วยให้คุณสร้างท่อป้อนข้อมูลที่ซับซ้อนจากชิ้นส่วนที่เรียบง่ายและนำกลับมาใช้ใหม่ได้ ตัวอย่างเช่นไปป์ไลน์สำหรับโมเดลรูปภาพอาจรวมข้อมูลจากไฟล์ในระบบไฟล์แบบกระจายใช้การรบกวนแบบสุ่มกับแต่ละรูปภาพและรวมรูปภาพที่เลือกแบบสุ่มเป็นชุดสำหรับการฝึกอบรม ไปป์ไลน์สำหรับแบบจำลองข้อความอาจเกี่ยวข้องกับการแยกสัญลักษณ์จากข้อมูลข้อความดิบการแปลงเป็นการฝังตัวระบุด้วยตารางการค้นหาและการรวมลำดับที่มีความยาวต่างกันเข้าด้วยกัน tf.data API ทำให้สามารถจัดการข้อมูลจำนวนมากอ่านจากรูปแบบข้อมูลที่แตกต่างกันและทำการแปลงที่ซับซ้อนได้

tf.data API แนะนำ tf.data.Dataset ข้อมูลที่เป็นนามธรรมที่แสดงถึงลำดับขององค์ประกอบซึ่งแต่ละองค์ประกอบประกอบด้วยส่วนประกอบอย่างน้อยหนึ่งรายการ ตัวอย่างเช่นในไปป์ไลน์รูปภาพองค์ประกอบอาจเป็นตัวอย่างการฝึกอบรมเดียวโดยมีส่วนประกอบเทนเซอร์ที่เป็นตัวแทนของรูปภาพและป้ายกำกับ

มีสองวิธีที่แตกต่างกันในการสร้างชุดข้อมูล:

  • แหล่ง ข้อมูลสร้าง Dataset จากข้อมูลที่จัดเก็บในหน่วยความจำหรือในไฟล์อย่างน้อยหนึ่งไฟล์

  • การ แปลง ข้อมูลสร้างชุดข้อมูลจากวัตถุ tf.data.Dataset อย่างน้อยหนึ่งรายการ

import tensorflow as tf
import pathlib
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

np.set_printoptions(precision=4)

กลศาสตร์พื้นฐาน

ในการสร้างไปป์ไลน์อินพุตคุณต้องเริ่มต้นด้วย แหล่ง ข้อมูล ตัวอย่างเช่นในการสร้าง Dataset จากข้อมูลในหน่วยความจำคุณสามารถใช้ tf.data.Dataset.from_tensors() หรือ tf.data.Dataset.from_tensor_slices() หรือหากข้อมูลอินพุตของคุณถูกเก็บไว้ในไฟล์ในรูปแบบ TFRecord ที่แนะนำคุณสามารถใช้ tf.data.TFRecordDataset()

เมื่อคุณมีออบเจ็กต์ Dataset แล้วคุณสามารถ แปลง เป็น Dataset ใหม่ได้โดยการเชื่อมเมธอดการเรียกใช้อ็อบเจ็กต์ tf.data.Dataset ตัวอย่างเช่นคุณสามารถใช้การแปลงต่อองค์ประกอบเช่น Dataset.map() และการแปลงหลายองค์ประกอบเช่น Dataset.batch() ดูเอกสารสำหรับ tf.data.Dataset สำหรับรายการการแปลงทั้งหมด

วัตถุ Dataset เป็น Python ที่ทำซ้ำได้ สิ่งนี้ทำให้สามารถใช้องค์ประกอบโดยใช้ for loop:

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset
<TensorSliceDataset shapes: (), types: tf.int32>
for elem in dataset:
  print(elem.numpy())
8
3
0
8
2
1

หรือโดยการสร้างตัววนซ้ำ Python อย่างชัดเจนโดยใช้ iter และใช้องค์ประกอบโดยใช้ next :

it = iter(dataset)

print(next(it).numpy())
8

หรืออีกวิธีหนึ่งคือสามารถใช้องค์ประกอบชุดข้อมูลโดยใช้การ reduce แปลงซึ่งจะลดองค์ประกอบทั้งหมดเพื่อให้ได้ผลลัพธ์เดียว ตัวอย่างต่อไปนี้แสดงให้เห็นถึงวิธีการใช้การแปลงแบบ reduce เพื่อคำนวณผลรวมของชุดข้อมูลของจำนวนเต็ม

print(dataset.reduce(0, lambda state, value: state + value).numpy())
22

โครงสร้างชุดข้อมูล

ชุดข้อมูลประกอบด้วยองค์ประกอบที่แต่ละส่วนมีโครงสร้าง (ซ้อนกัน) เหมือนกันและส่วนประกอบแต่ละส่วนของโครงสร้างสามารถเป็นประเภทใดก็ได้ที่แสดงโดย tf.TypeSpec รวมถึง tf.Tensor , tf.sparse.SparseTensor , tf.RaggedTensor , tf.TensorArray หรือ tf.data.Dataset

คุณสมบัติ Dataset.element_spec ช่วยให้คุณตรวจสอบประเภทขององค์ประกอบแต่ละองค์ประกอบ คุณสมบัติส่งคืน โครงสร้างที่ซ้อนกัน ของ tf.TypeSpec โดยจับคู่โครงสร้างขององค์ประกอบซึ่งอาจเป็นส่วนประกอบเดียวทูเพิลของคอมโพเนนต์หรือทูเพิลที่ซ้อนกันของคอมโพเนนต์ ตัวอย่างเช่น:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

dataset1.element_spec
TensorSpec(shape=(10,), dtype=tf.float32, name=None)
dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2.element_spec
(TensorSpec(shape=(), dtype=tf.float32, name=None),
 TensorSpec(shape=(100,), dtype=tf.int32, name=None))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3.element_spec
(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
 (TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))
# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))

dataset4.element_spec
SparseTensorSpec(TensorShape([3, 4]), tf.int32)
# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type
tensorflow.python.framework.sparse_tensor.SparseTensor

Dataset ชุดข้อมูลสนับสนุนการเปลี่ยนแปลงของโครงสร้างใด ๆ เมื่อใช้การ Dataset.map() และ Dataset.filter() ซึ่งใช้ฟังก์ชันกับแต่ละองค์ประกอบโครงสร้างองค์ประกอบจะกำหนดอาร์กิวเมนต์ของฟังก์ชัน:

dataset1 = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))

dataset1
<TensorSliceDataset shapes: (10,), types: tf.int32>
for z in dataset1:
  print(z.numpy())
[1 2 5 3 8 9 6 7 9 3]
[5 4 3 3 9 7 3 1 5 4]
[7 3 4 8 4 8 5 9 4 1]
[8 6 8 2 6 3 5 8 3 3]

dataset2 = tf.data.Dataset.from_tensor_slices(
   (tf.random.uniform([4]),
    tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2
<TensorSliceDataset shapes: ((), (100,)), types: (tf.float32, tf.int32)>
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3
<ZipDataset shapes: ((10,), ((), (100,))), types: (tf.int32, (tf.float32, tf.int32))>
for a, (b,c) in dataset3:
  print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)

การอ่านข้อมูลอินพุต

การใช้อาร์เรย์ NumPy

ดูการ โหลดอาร์เรย์ NumPy สำหรับตัวอย่างเพิ่มเติม

หากข้อมูลอินพุตทั้งหมดของคุณพอดีกับหน่วยความจำวิธีที่ง่ายที่สุดในการสร้าง Dataset จากพวกเขาคือการแปลงเป็นวัตถุ tf.Tensor และใช้ Dataset.from_tensor_slices()

train, test = tf.keras.datasets.fashion_mnist.load_data()
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 1s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

images, labels = train
images = images/255

dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset
<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>

การใช้เครื่องกำเนิด Python

แหล่งข้อมูลทั่วไปอีกแหล่งที่สามารถนำเข้าได้อย่างง่ายดายเป็น tf.data.Dataset คือตัวสร้าง python

def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1
for n in count(5):
  print(n)
0
1
2
3
4

ตัวสร้าง Dataset.from_generator แปลงตัวสร้าง python เป็น tf.data.Dataset ทำงานได้อย่างสมบูรณ์

ตัวสร้างจะเรียกได้ว่าเป็นอินพุตไม่ใช่ตัวสร้างซ้ำ สิ่งนี้ช่วยให้สามารถรีสตาร์ทเครื่องกำเนิดไฟฟ้าเมื่อถึงจุดสิ้นสุด ใช้ args ซึ่งเป็นทางเลือกซึ่งส่งผ่านเป็นอาร์กิวเมนต์ของ callable

จำเป็นต้องใช้อาร์กิวเมนต์ output_types เนื่องจาก tf.data สร้าง tf.Graph ภายในและขอบกราฟต้องการ tf.dtype

ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes = (), )
for count_batch in ds_counter.repeat().batch(10).take(10):
  print(count_batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]

ไม่ จำเป็นต้องใช้ อาร์กิวเมนต์ output_shapes แต่ขอแนะนำอย่างมากเนื่องจากการดำเนินการเทนเซอร์โฟลว์จำนวนมากไม่รองรับเทนเซอร์ที่มีอันดับที่ไม่รู้จัก หากไม่ทราบความยาวของแกนเฉพาะหรือตัวแปรให้ตั้งค่าเป็น None ใน output_shapes

สิ่งสำคัญคือต้องสังเกตว่า output_shapes และ output_types ไปตามกฎการซ้อนกันเหมือนกับวิธีการชุดข้อมูลอื่น ๆ

นี่คือตัวสร้างตัวอย่างที่แสดงให้เห็นทั้งสองด้านโดยจะส่งคืนค่าทูเปิลของอาร์เรย์โดยที่อาร์เรย์ที่สองเป็นเวกเตอร์ที่ไม่ทราบความยาว

def gen_series():
  i = 0
  while True:
    size = np.random.randint(0, 10)
    yield i, np.random.normal(size=(size,))
    i += 1
for i, series in gen_series():
  print(i, ":", str(series))
  if i > 5:
    break
0 : [1.2226]
1 : [ 0.4785  1.1887 -0.2828  0.6047  1.3367 -0.4387  0.1822]
2 : [ 1.1343 -0.2676  0.0224 -0.111  -0.1384 -1.9315]
3 : [-0.8651]
4 : [0.6275]
5 : [ 0.8034  2.0773  0.6183 -0.3746]
6 : [-0.9439 -0.6686]

เอาต์พุตแรกคือ int32 ส่วนที่สองคือ float32

รายการแรกคือสเกลาร์รูปร่าง () และรายการที่สองคือเวกเตอร์ที่ไม่ทราบความยาวรูปร่าง (None,)

ds_series = tf.data.Dataset.from_generator(
    gen_series, 
    output_types=(tf.int32, tf.float32), 
    output_shapes=((), (None,)))

ds_series
<FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float32)>

ตอนนี้สามารถใช้งานได้เหมือน tf.data.Dataset ทั่วไป โปรดทราบว่าเมื่อรวมชุดข้อมูลที่มีรูปร่างตัวแปรคุณต้องใช้ Dataset.padded_batch

ds_series_batch = ds_series.shuffle(20).padded_batch(10)

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())
[18  1  8 22  5 13  6  2 14 28]

[[ 0.0196 -1.007   0.1843  0.0289  0.0735 -0.6279 -1.0877  0.    ]
 [ 0.8238 -1.0284  0.      0.      0.      0.      0.      0.    ]
 [ 0.544  -1.1061  1.2368  0.3975  0.      0.      0.      0.    ]
 [-1.1933 -0.7535  1.0497  0.1764  1.5319  0.9202  0.4027  0.6844]
 [-1.2025 -2.7148  1.0702  1.3893  0.      0.      0.      0.    ]
 [ 1.3787  0.6817  0.1197 -0.1178  0.9764  0.8895  0.      0.    ]
 [-1.523  -0.7722 -2.13    0.2761 -1.1094  0.      0.      0.    ]
 [ 0.203   2.1858 -0.722   1.2554  1.2208  0.1813 -2.3427  0.    ]
 [ 0.7685  1.7138 -1.2376 -2.6168  0.2565  0.0753  1.5653  0.    ]
 [ 0.0908 -0.4535  0.1257 -0.5122  0.      0.      0.      0.    ]]

สำหรับตัวอย่างที่เป็นจริงมากขึ้นให้ลองตัด preprocessing.image.ImageDataGenerator เป็น tf.data.Dataset

ขั้นแรกดาวน์โหลดข้อมูล:

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 9s 0us/step

สร้าง image.ImageDataGenerator

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers))
Found 3670 images belonging to 5 classes.

print(images.dtype, images.shape)
print(labels.dtype, labels.shape)
float32 (32, 256, 256, 3)
float32 (32, 5)

ds = tf.data.Dataset.from_generator(
    lambda: img_gen.flow_from_directory(flowers), 
    output_types=(tf.float32, tf.float32), 
    output_shapes=([32,256,256,3], [32,5])
)

ds.element_spec
(TensorSpec(shape=(32, 256, 256, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(32, 5), dtype=tf.float32, name=None))
for images, label in ds.take(1):
  print('images.shape: ', images.shape)
  print('labels.shape: ', labels.shape)

Found 3670 images belonging to 5 classes.
images.shape:  (32, 256, 256, 3)
labels.shape:  (32, 5)

การใช้ข้อมูล TFRecord

ดูการ โหลด TFRecords สำหรับตัวอย่าง end-to-end

tf.data API รองรับไฟล์รูปแบบต่างๆเพื่อให้คุณประมวลผลชุดข้อมูลขนาดใหญ่ที่ไม่พอดีกับหน่วยความจำได้ ตัวอย่างเช่นรูปแบบไฟล์ TFRecord เป็นรูปแบบไบนารีเชิงบันทึกที่เรียบง่ายซึ่งแอปพลิเคชัน TensorFlow จำนวนมากใช้สำหรับข้อมูลการฝึกอบรม คลาส tf.data.TFRecordDataset ช่วยให้คุณสามารถสตรีมผ่านเนื้อหาของไฟล์ TFRecord ตั้งแต่หนึ่งไฟล์ขึ้นไปโดยเป็นส่วนหนึ่งของไปป์ไลน์อินพุต

นี่คือตัวอย่างการใช้ไฟล์ทดสอบจาก French Street Name Signs (FSNS)

# Creates a dataset that reads all of the examples from two files.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
7905280/7904079 [==============================] - 0s 0us/step

อาร์กิวเมนต์ filenames สำหรับตัวเริ่มต้น TFRecordDataset อาจเป็นสตริงรายการสตริงหรือ tf.Tensor ของสตริงก็ได้ ดังนั้นหากคุณมีไฟล์สองชุดสำหรับวัตถุประสงค์ในการฝึกอบรมและการตรวจสอบความถูกต้องคุณสามารถสร้างเมธอดจากโรงงานที่สร้างชุดข้อมูลโดยใช้ชื่อไฟล์เป็นอาร์กิวเมนต์อินพุต:

dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>

โครงการ TensorFlow จำนวนมากใช้ระเบียน tf.train.Example ในไฟล์ TFRecord ที่เป็นอนุกรม สิ่งเหล่านี้ต้องได้รับการถอดรหัสก่อนจึงจะตรวจสอบได้:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

parsed.features.feature['image/text']
bytes_list {
  value: "Rue Perreyon"
}

การใช้ข้อมูลข้อความ

ดูการ โหลดข้อความ สำหรับตัวอย่าง end to end

ชุดข้อมูลจำนวนมากถูกแจกจ่ายเป็นไฟล์ข้อความอย่างน้อยหนึ่งไฟล์ tf.data.TextLineDataset มีวิธีง่ายๆในการแยกบรรทัดจากไฟล์ข้อความอย่างน้อยหนึ่งไฟล์ ด้วยชื่อไฟล์อย่างน้อยหนึ่งชื่อ TextLineDataset จะสร้างองค์ประกอบที่มีค่าสตริงหนึ่งรายการต่อบรรทัดของไฟล์เหล่านั้น

directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']

file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
819200/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
811008/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
811008/807992 [==============================] - 0s 0us/step

dataset = tf.data.TextLineDataset(file_paths)

นี่คือสองสามบรรทัดแรกของไฟล์แรก:

for line in dataset.take(5):
  print(line.numpy())
b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b'His wrath pernicious, who ten thousand woes'
b"Caused to Achaia's host, sent many a soul"
b'Illustrious into Ades premature,'
b'And Heroes gave (so stood the will of Jove)'

หากต้องการสลับบรรทัดระหว่างไฟล์ให้ใช้ Dataset.interleave ทำให้ง่ายต่อการสลับไฟล์ร่วมกัน บรรทัดแรกบรรทัดที่สองและสามจากการแปลแต่ละครั้งมีดังนี้

files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)

for i, line in enumerate(lines_ds.take(9)):
  if i % 3 == 0:
    print()
  print(line.numpy())

b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse,"
b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought'

b'His wrath pernicious, who ten thousand woes'
b'The vengeance, deep and deadly; whence to Greece'
b'countless ills upon the Achaeans. Many a brave soul did it send'

b"Caused to Achaia's host, sent many a soul"
b'Unnumbered ills arose; which many a soul'
b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'

ตามค่าเริ่มต้น TextLineDataset จะให้ ทุก บรรทัดของแต่ละไฟล์ซึ่งอาจไม่เป็นที่ต้องการตัวอย่างเช่นถ้าไฟล์เริ่มต้นด้วยบรรทัดส่วนหัวหรือมีข้อคิดเห็น เส้นเหล่านี้สามารถลบออกได้โดยใช้การ Dataset.skip() หรือ Dataset.filter() ที่นี่คุณข้ามบรรทัดแรกจากนั้นกรองเพื่อค้นหาผู้รอดชีวิตเท่านั้น

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step

for line in titanic_lines.take(10):
  print(line.numpy())
b'survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone'
b'0,male,22.0,1,0,7.25,Third,unknown,Southampton,n'
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y'
b'0,male,2.0,3,1,21.075,Third,unknown,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'

def survived(line):
  return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)
for line in survivors.take(10):
  print(line.numpy())
b'1,female,38.0,1,0,71.2833,First,C,Cherbourg,n'
b'1,female,26.0,0,0,7.925,Third,unknown,Southampton,y'
b'1,female,35.0,1,0,53.1,First,C,Southampton,n'
b'1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n'
b'1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n'
b'1,female,4.0,1,1,16.7,Third,G,Southampton,n'
b'1,male,28.0,0,0,13.0,Second,unknown,Southampton,y'
b'1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y'
b'1,male,28.0,0,0,35.5,First,A,Southampton,y'
b'1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n'

การใช้ข้อมูล CSV

ดูการ โหลดไฟล์ CSV และการ โหลด Pandas DataFrames สำหรับตัวอย่างเพิ่มเติม

รูปแบบไฟล์ CSV เป็นรูปแบบยอดนิยมสำหรับการจัดเก็บข้อมูลแบบตารางเป็นข้อความธรรมดา

ตัวอย่างเช่น:

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
df = pd.read_csv(titanic_file, index_col=None)
df.head()

หากข้อมูลของคุณพอดีกับหน่วยความจำวิธี Dataset.from_tensor_slices เดียวกันจะทำงานบนพจนานุกรมทำให้สามารถนำเข้าข้อมูลนี้ได้อย่างง่ายดาย:

titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
  'survived'          : 0
  'sex'               : b'male'
  'age'               : 22.0
  'n_siblings_spouses': 1
  'parch'             : 0
  'fare'              : 7.25
  'class'             : b'Third'
  'deck'              : b'unknown'
  'embark_town'       : b'Southampton'
  'alone'             : b'n'

แนวทางที่ปรับขนาดได้มากขึ้นคือการโหลดจากดิสก์เท่าที่จำเป็น

โมดูล tf.data จัดเตรียมวิธีการแยกระเบียนจากไฟล์ CSV ตั้งแต่หนึ่งไฟล์ขึ้นไปที่สอดคล้องกับ RFC 4180

experimental.make_csv_dataset ฟังก์ชั่นคือการติดต่อระดับสูงสำหรับการอ่านชุดไฟล์ CSV รองรับการอนุมานประเภทคอลัมน์และคุณสมบัติอื่น ๆ อีกมากมายเช่นการแบทช์และการสับเพื่อให้การใช้งานง่ายขึ้น

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  print("features:")
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [0 0 0 0]
features:
  'sex'               : [b'male' b'male' b'male' b'male']
  'age'               : [18. 19. 31. 34.]
  'n_siblings_spouses': [0 0 0 1]
  'parch'             : [0 0 0 0]
  'fare'              : [ 8.3    8.05   7.775 26.   ]
  'class'             : [b'Third' b'Third' b'Third' b'Second']
  'deck'              : [b'unknown' b'unknown' b'unknown' b'unknown']
  'embark_town'       : [b'Southampton' b'Southampton' b'Southampton' b'Southampton']
  'alone'             : [b'y' b'y' b'y' b'n']

คุณสามารถใช้อาร์กิวเมนต์ select_columns หากคุณต้องการเพียงคอลัมน์ย่อย

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])
for feature_batch, label_batch in titanic_batches.take(1):
  print("'survived': {}".format(label_batch))
  for key, value in feature_batch.items():
    print("  {!r:20s}: {}".format(key, value))
'survived': [1 0 0 1]
  'fare'              : [ 20.25     9.     110.8833  59.4   ]
  'class'             : [b'Third' b'Third' b'First' b'First']

นอกจากนี้ยังมี experimental.CsvDataset ระดับล่างคลาส CSvDataset ซึ่งให้การควบคุมที่ละเอียดยิ่งขึ้น ไม่สนับสนุนการอนุมานประเภทคอลัมน์ คุณต้องระบุประเภทของแต่ละคอลัมน์แทน

titanic_types  = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string] 
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types , header=True)

for line in dataset.take(10):
  print([item.numpy() for item in line])
[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
[1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
[1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
[0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
[0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
[1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
[0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']

หากบางคอลัมน์ว่างเปล่าอินเทอร์เฟซระดับต่ำนี้จะช่วยให้คุณระบุค่าเริ่มต้นแทนประเภทคอลัมน์

%%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,
Writing missing.csv

# Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values.

record_defaults = [999,999,999,999]
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults)
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<MapDataset shapes: (4,), types: tf.int32>
for line in dataset:
  print(line.numpy())
[1 2 3 4]
[999   2   3   4]
[  1 999   3   4]
[  1   2 999   4]
[  1   2   3 999]
[999 999 999 999]

ตามค่าเริ่มต้นชุดข้อมูล CsvDataset จะให้ ทุก คอลัมน์ของ ทุก บรรทัดของไฟล์ซึ่งอาจไม่เป็นที่ต้องการตัวอย่างเช่นหากไฟล์เริ่มต้นด้วยบรรทัดส่วนหัวที่ควรละเว้นหรือหากบางคอลัมน์ไม่จำเป็นต้องใช้ในอินพุต บรรทัดและฟิลด์เหล่านี้สามารถลบออกได้โดยใช้อาร์กิวเมนต์ header และ select_cols ตามลำดับ

# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [999, 999] # Only provide defaults for the selected columns
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults, select_cols=[1, 3])
dataset = dataset.map(lambda *items: tf.stack(items))
dataset
<MapDataset shapes: (2,), types: tf.int32>
for line in dataset:
  print(line.numpy())
[2 4]
[2 4]
[999   4]
[2 4]
[  2 999]
[999 999]

การใช้ชุดไฟล์

มีชุดข้อมูลจำนวนมากที่แจกจ่ายเป็นชุดไฟล์โดยแต่ละไฟล์เป็นตัวอย่าง

flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)

ไดเร็กทอรีรูทมีไดเร็กทอรีสำหรับแต่ละคลาส:

for item in flowers_root.glob("*"):
  print(item.name)
sunflowers
daisy
LICENSE.txt
roses
tulips
dandelion

ไฟล์ในแต่ละคลาสไดเร็กทอรีเป็นตัวอย่าง:

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

for f in list_ds.take(5):
  print(f.numpy())
b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/4933229095_f7e4218b28.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/dandelion/18282528206_7fb3166041.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/daisy/3711723108_65247a3170.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/sunflowers/4019748730_ee09b39a43.jpg'
b'/home/kbuilder/.keras/datasets/flower_photos/roses/475936554_a2b38aaa8e.jpg'

อ่านข้อมูลโดยใช้ฟังก์ชัน tf.io.read_file และแยกฉลากออกจากพา ธ ส่งคืนคู่ (image, label) :

def process_path(file_path):
  label = tf.strings.split(file_path, os.sep)[-2]
  return tf.io.read_file(file_path), label

labeled_ds = list_ds.map(process_path)
for image_raw, label_text in labeled_ds.take(1):
  print(repr(image_raw.numpy()[:100]))
  print()
  print(label_text.numpy())
b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xe2\x0cXICC_PROFILE\x00\x01\x01\x00\x00\x0cHLino\x02\x10\x00\x00mntrRGB XYZ \x07\xce\x00\x02\x00\t\x00\x06\x001\x00\x00acspMSFT\x00\x00\x00\x00IEC sRGB\x00\x00\x00\x00\x00\x00'

b'roses'

องค์ประกอบชุดข้อมูลแบบแบตช์

การผสมแบบง่าย

รูปแบบที่ง่ายที่สุดของ batching stacks n องค์ประกอบที่ต่อเนื่องกันของชุดข้อมูลเป็นองค์ประกอบเดียว การ Dataset.batch() ทำเช่นนี้โดยมีข้อ จำกัด เดียวกันกับตัวดำเนินการ tf.stack() ใช้กับแต่ละองค์ประกอบขององค์ประกอบนั่นคือสำหรับแต่ละองค์ประกอบ i องค์ประกอบทั้งหมดต้องมีค่าเทนเซอร์ที่มีรูปร่างเหมือนกัน

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

for batch in batched_dataset.take(4):
  print([arr.numpy() for arr in batch])
[array([0, 1, 2, 3]), array([ 0, -1, -2, -3])]
[array([4, 5, 6, 7]), array([-4, -5, -6, -7])]
[array([ 8,  9, 10, 11]), array([ -8,  -9, -10, -11])]
[array([12, 13, 14, 15]), array([-12, -13, -14, -15])]

ในขณะที่ tf.data พยายามเผยแพร่ข้อมูลรูปร่างการตั้งค่าเริ่มต้นของ Dataset.batch ส่งผลให้มีขนาดชุดงานที่ไม่รู้จักเนื่องจากชุดงานสุดท้ายอาจไม่เต็ม สังเกตว่า None ในรูปร่าง:

batched_dataset
<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.int64)>

ใช้อาร์กิวเมนต์ drop_remainder เพื่อละเว้นชุดสุดท้ายและรับการขยายรูปร่างแบบเต็ม:

batched_dataset = dataset.batch(7, drop_remainder=True)
batched_dataset
<BatchDataset shapes: ((7,), (7,)), types: (tf.int64, tf.int64)>

การรวมเทนเซอร์พร้อมช่องว่างภายใน

สูตรข้างต้นใช้ได้กับเทนเซอร์ที่มีขนาดเท่ากันทั้งหมด อย่างไรก็ตามโมเดลจำนวนมาก (เช่นแบบจำลองลำดับ) ทำงานกับข้อมูลอินพุตที่มีขนาดแตกต่างกัน (เช่นลำดับความยาวต่างกัน) ในการจัดการกรณีนี้การแปลง Dataset.padded_batch ช่วยให้คุณสามารถแบตช์เทนเซอร์ที่มีรูปร่างต่างกันได้โดยการระบุมิติข้อมูลอย่างน้อยหนึ่งมิติซึ่งอาจมีการบุนวม

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=(None,))

for batch in dataset.take(2):
  print(batch.numpy())
  print()

[[0 0 0]
 [1 0 0]
 [2 2 0]
 [3 3 3]]

[[4 4 4 4 0 0 0]
 [5 5 5 5 5 0 0]
 [6 6 6 6 6 6 0]
 [7 7 7 7 7 7 7]]


การแปลง Dataset.padded_batch ช่วยให้คุณตั้งค่าช่องว่างภายในที่แตกต่างกันสำหรับแต่ละมิติของแต่ละองค์ประกอบและอาจเป็นความยาวผันแปร (แสดงโดย None ในตัวอย่างด้านบน) หรือความยาวคงที่ นอกจากนี้ยังสามารถแทนที่ค่าช่องว่างภายในซึ่งมีค่าเริ่มต้นเป็น 0

เวิร์กโฟลว์การฝึกอบรม

การประมวลผลหลายยุค

tf.data API นำเสนอสองวิธีหลักในการประมวลผลหลายยุคของข้อมูลเดียวกัน

วิธีที่ง่ายที่สุดในการวนซ้ำชุดข้อมูลในหลายยุคคือการใช้การ Dataset.repeat() ขั้นแรกสร้างชุดข้อมูลไททานิก:

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
def plot_batch_sizes(ds):
  batch_sizes = [batch.shape[0] for batch in ds]
  plt.bar(range(len(batch_sizes)), batch_sizes)
  plt.xlabel('Batch number')
  plt.ylabel('Batch size')

การใช้การ Dataset.repeat() โดยไม่มีอาร์กิวเมนต์จะเป็นการป้อนข้อมูลซ้ำไปเรื่อย ๆ

การแปลง Dataset.repeat เชื่อมต่ออาร์กิวเมนต์โดยไม่ส่งสัญญาณการสิ้นสุดของยุคหนึ่งและจุดเริ่มต้นของยุคถัดไป ด้วยเหตุนี้ Dataset.batch ใช้ Dataset.batch หลังจาก Dataset.repeat จะให้ชุดงานที่คร่อมขอบเขตยุค:

titanic_batches = titanic_lines.repeat(3).batch(128)
plot_batch_sizes(titanic_batches)

png

หากคุณต้องการการแยกยุคที่ชัดเจนให้ใส่ Dataset.batch ก่อนทำซ้ำ:

titanic_batches = titanic_lines.batch(128).repeat(3)

plot_batch_sizes(titanic_batches)

png

หากคุณต้องการทำการคำนวณแบบกำหนดเอง (เช่นเพื่อรวบรวมสถิติ) ในตอนท้ายของแต่ละยุคการเริ่มต้นการทำซ้ำชุดข้อมูลในแต่ละยุคจะทำได้ง่ายที่สุด:

epochs = 3
dataset = titanic_lines.batch(128)

for epoch in range(epochs):
  for batch in dataset:
    print(batch.shape)
  print("End of epoch: ", epoch)
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  0
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  1
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch:  2

สุ่มสับข้อมูลอินพุต

การ Dataset.shuffle() รักษาบัฟเฟอร์ขนาดคงที่และเลือกองค์ประกอบถัดไปแบบสุ่มจากบัฟเฟอร์นั้น

เพิ่มดัชนีลงในชุดข้อมูลเพื่อให้คุณเห็นผล:

lines = tf.data.TextLineDataset(titanic_file)
counter = tf.data.experimental.Counter()

dataset = tf.data.Dataset.zip((counter, lines))
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(20)
dataset
<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.string)>

เนื่องจาก buffer_size คือ 100 และขนาดแบตช์คือ 20 ชุดงานแรกจึงไม่มีองค์ประกอบที่มีดัชนีมากกว่า 120

n,line_batch = next(iter(dataset))
print(n.numpy())
[ 29  39  94  10  81  85  74   3  91  53  87  17   1  64  54 107  20  63
 116  62]

เช่นเดียวกับ Dataset.batch ลำดับที่สัมพันธ์กับ Dataset.repeat สำคัญ

Dataset.shuffle ไม่ส่งสัญญาณการสิ้นสุดยุคจนกว่าบัฟเฟอร์การสุ่มจะว่างเปล่า ดังนั้นการสับเปลี่ยนที่วางไว้ก่อนการเล่นซ้ำจะแสดงทุกองค์ประกอบของยุคหนึ่งก่อนที่จะย้ายไปที่ถัดไป:

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.shuffle(buffer_size=100).batch(10).repeat(2)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(60).take(5):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[469 560 614 431 557 558 618 616 571 532]
[464 403 509 589 365 610 596 428 600 605]
[539 467 163 599 624 545 552 548]
[92 30 73 51 72 71 56 83 32 46]
[ 74 106  75  78  26  21   3  54  80  50]

shuffle_repeat = [n.numpy().mean() for n, line_batch in shuffled]
plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7f79e43368d0>

png

แต่การทำซ้ำก่อนการสุ่มจะผสมผสานขอบเขตยุคเข้าด้วยกัน:

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.repeat(2).shuffle(buffer_size=100).batch(10)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(55).take(15):
  print(n.numpy())
Here are the item ID's near the epoch boundary:

[618 517  22  13 602 608  25 621  12 579]
[568 620 534 547 319   1  17 346  29 581]
[537 395 560  33 617 557 561 610 585 627]
[428 607  32   0 455 496 605 624   6 596]
[511 588  14 572   4  21  66 471  41  36]
[583 622  49  54 616  75  65  57  76   7]
[ 77  37 574 586  74  56  28 419  73   5]
[599  83  31  61  91 481 577  98  26  50]
[  2 625  85  23 606 102 550  20 612  15]
[ 24  63 592 615 388  39  45  51  68  97]
[ 46  34  67  55  79  94 548  47  72 436]
[127  89  30  90 126 601 360  87 104  52]
[ 59 130 106 584 118 265  48  70 597 121]
[626 105 135  10  38  53 132 136 101   9]
[512 542 109 154 108  42 114  11  58 166]

repeat_shuffle = [n.numpy().mean() for n, line_batch in shuffled]

plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.plot(repeat_shuffle, label="repeat().shuffle()")
plt.ylabel("Mean item ID")
plt.legend()
<matplotlib.legend.Legend at 0x7f79e42b60b8>

png

การประมวลผลข้อมูลล่วงหน้า

การ Dataset.map(f) สร้างชุดข้อมูลใหม่โดยใช้ฟังก์ชันที่กำหนด f กับแต่ละองค์ประกอบของชุดข้อมูลอินพุต มันขึ้นอยู่กับฟังก์ชัน map() ที่มักใช้กับรายการ (และโครงสร้างอื่น ๆ ) ในภาษาโปรแกรมการทำงาน ฟังก์ชัน f รับวัตถุ tf.Tensor ที่แสดงถึงองค์ประกอบเดียวในอินพุตและส่งคืน tf.Tensor ที่จะแทนองค์ประกอบเดียวในชุดข้อมูลใหม่ การนำไปใช้งานใช้การดำเนินการ TensorFlow มาตรฐานเพื่อเปลี่ยนองค์ประกอบหนึ่งเป็นอีกองค์ประกอบหนึ่ง

ส่วนนี้ครอบคลุมตัวอย่างทั่วไปของวิธีใช้ Dataset.map()

การถอดรหัสข้อมูลภาพและปรับขนาด

เมื่อฝึกอบรมเครือข่ายประสาทเทียมกับข้อมูลภาพในโลกแห่งความเป็นจริงมักจำเป็นต้องแปลงภาพที่มีขนาดแตกต่างกันเป็นขนาดทั่วไปเพื่อที่จะรวมเป็นขนาดคงที่

สร้างชุดข้อมูลชื่อไฟล์ดอกไม้ใหม่:

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

เขียนฟังก์ชันที่จัดการกับองค์ประกอบชุดข้อมูล

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def parse_image(filename):
  parts = tf.strings.split(filename, os.sep)
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.image.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)
  image = tf.image.resize(image, [128, 128])
  return image, label

ทดสอบว่าใช้งานได้จริง

file_path = next(iter(list_ds))
image, label = parse_image(file_path)

def show(image, label):
  plt.figure()
  plt.imshow(image)
  plt.title(label.numpy().decode('utf-8'))
  plt.axis('off')

show(image, label)

png

แมปไว้เหนือชุดข้อมูล

images_ds = list_ds.map(parse_image)

for image, label in images_ds.take(2):
  show(image, label)

png

png

ใช้ตรรกะ Python โดยพลการ

ด้วยเหตุผลด้านประสิทธิภาพให้ใช้การดำเนินการ TensorFlow สำหรับการประมวลผลข้อมูลล่วงหน้าทุกครั้งที่ทำได้ อย่างไรก็ตามบางครั้งการเรียกไลบรารี Python ภายนอกก็มีประโยชน์เมื่อแยกวิเคราะห์ข้อมูลอินพุตของคุณ คุณสามารถใช้ tf.py_function() การดำเนินงานใน Dataset.map() การเปลี่ยนแปลง

ตัวอย่างเช่นหากคุณต้องการใช้การหมุนแบบสุ่มโมดูล tf.image จะมีเฉพาะ tf.image.rot90 ซึ่งไม่มีประโยชน์มากนักสำหรับการเพิ่มรูปภาพ

หากต้องการสาธิต tf.py_function ให้ลองใช้ฟังก์ชัน scipy.ndimage.rotate แทน:

import scipy.ndimage as ndimage

def random_rotate_image(image):
  image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
  return image
image, label = next(iter(images_ds))
image = random_rotate_image(image)
show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

ในการใช้ฟังก์ชันนี้กับ Dataset.map จะใช้คำเตือนเดียวกันกับ Dataset.from_generator คุณต้องอธิบายรูปร่างและประเภทที่ส่งคืนเมื่อคุณใช้ฟังก์ชัน:

def tf_random_rotate_image(image, label):
  im_shape = image.shape
  [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])
  image.set_shape(im_shape)
  return image, label
rot_ds = images_ds.map(tf_random_rotate_image)

for image, label in rot_ds.take(2):
  show(image, label)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

png

การแยกวิเคราะห์ tf.Example ข้อความบัฟเฟอร์โปรโตคอล

ไปป์ไลน์อินพุตจำนวนมากดึงข้อความบัฟเฟอร์โปรโตคอล tf.train.Example จากรูปแบบ TFRecord แต่ละระเบียน tf.train.Example มี "คุณลักษณะ" อย่างน้อยหนึ่งรายการและโดยทั่วไปไปป์ไลน์อินพุตจะแปลงคุณลักษณะเหล่านี้เป็นเทนเซอร์

fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset
<TFRecordDatasetV2 shapes: (), types: tf.string>

คุณสามารถทำงานกับ tf.train.Example protos นอก tf.data.Dataset เพื่อทำความเข้าใจข้อมูล:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

feature = parsed.features.feature
raw_img = feature['image/encoded'].bytes_list.value[0]
img = tf.image.decode_png(raw_img)
plt.imshow(img)
plt.axis('off')
_ = plt.title(feature["image/text"].bytes_list.value[0])

png

raw_example = next(iter(dataset))
def tf_parse(eg):
  example = tf.io.parse_example(
      eg[tf.newaxis], {
          'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
          'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
      })
  return example['image/encoded'][0], example['image/text'][0]
img, txt = tf_parse(raw_example)
print(txt.numpy())
print(repr(img.numpy()[:20]), "...")
b'Rue Perreyon'
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02X' ...

decoded = dataset.map(tf_parse)
decoded
<MapDataset shapes: ((), ()), types: (tf.string, tf.string)>
image_batch, text_batch = next(iter(decoded.batch(10)))
image_batch.shape
TensorShape([10])

หน้าต่างอนุกรมเวลา

สำหรับตัวอย่างอนุกรมเวลาตั้งแต่ต้นจนจบโปรดดู: การ คาดการณ์อนุกรมเวลา

ข้อมูลอนุกรมเวลามักถูกจัดระเบียบโดยให้แกนเวลาเหมือนเดิม

ใช้ Dataset.range แบบง่ายเพื่อสาธิต:

range_ds = tf.data.Dataset.range(100000)

โดยปกติโมเดลที่ใช้ข้อมูลประเภทนี้จะต้องการส่วนเวลาที่ต่อเนื่องกัน

วิธีที่ง่ายที่สุดคือการจัดกลุ่มข้อมูล:

ใช้ batch

batches = range_ds.batch(10, drop_remainder=True)

for batch in batches.take(5):
  print(batch.numpy())
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]

หรือหากต้องการคาดการณ์อย่างหนาแน่นไปอีกขั้นหนึ่งคุณอาจเปลี่ยนคุณลักษณะและป้ายกำกับทีละขั้นตอนโดยสัมพันธ์กัน:

def dense_1_step(batch):
  # Shift features and labels one step relative to each other.
  return batch[:-1], batch[1:]

predict_dense_1_step = batches.map(dense_1_step)

for features, label in predict_dense_1_step.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8]  =>  [1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18]  =>  [11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28]  =>  [21 22 23 24 25 26 27 28 29]

ในการทำนายทั้งหน้าต่างแทนที่จะเป็นค่าชดเชยคงที่คุณสามารถแบ่งแบทช์ออกเป็นสองส่วน:

batches = range_ds.batch(15, drop_remainder=True)

def label_next_5_steps(batch):
  return (batch[:-5],   # Take the first 5 steps
          batch[-5:])   # take the remainder

predict_5_steps = batches.map(label_next_5_steps)

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]  =>  [25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]  =>  [40 41 42 43 44]

หากต้องการให้บางส่วนทับซ้อนกันระหว่างคุณลักษณะของชุดงานหนึ่งกับป้ายกำกับของอีกชุดหนึ่งให้ใช้ Dataset.zip :

feature_length = 10
label_length = 5

features = range_ds.batch(feature_length, drop_remainder=True)
labels = range_ds.batch(feature_length).skip(1).map(lambda labels: labels[:-5])

predict_5_steps = tf.data.Dataset.zip((features, labels))

for features, label in predict_5_steps.take(3):
  print(features.numpy(), " => ", label.numpy())
[0 1 2 3 4 5 6 7 8 9]  =>  [10 11 12 13 14]
[10 11 12 13 14 15 16 17 18 19]  =>  [20 21 22 23 24]
[20 21 22 23 24 25 26 27 28 29]  =>  [30 31 32 33 34]

ใช้ window

ในขณะที่ใช้ Dataset.batch มีสถานการณ์ที่คุณอาจต้องการการควบคุมที่ละเอียดขึ้น เมธอด Dataset.window ช่วยให้คุณควบคุมได้อย่างสมบูรณ์ แต่ต้องการการดูแลบางอย่าง: จะส่งคืน Dataset ของ Datasets ดูรายละเอียด โครงสร้างชุดข้อมูล

window_size = 5

windows = range_ds.window(window_size, shift=1)
for sub_ds in windows.take(5):
  print(sub_ds)
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>

เมธอด Dataset.flat_map สามารถใช้ชุดข้อมูลของชุดข้อมูลและทำให้แบนเป็นชุดข้อมูลเดียว:

 for x in windows.flat_map(lambda x: x).take(30):
   print(x.numpy(), end=' ')
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7f79e44309d8> and will run it as-is.
Cause: could not parse the source code:

for x in windows.flat_map(lambda x: x).take(30):

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function <lambda> at 0x7f79e44309d8> and will run it as-is.
Cause: could not parse the source code:

for x in windows.flat_map(lambda x: x).take(30):

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 3 4 5 6 7 4 5 6 7 8 5 6 7 8 9 

ในเกือบทุกกรณีคุณจะต้อง. .batch ชุดข้อมูลก่อน:

def sub_to_batch(sub):
  return sub.batch(window_size, drop_remainder=True)

for example in windows.flat_map(sub_to_batch).take(5):
  print(example.numpy())
[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]

ตอนนี้คุณจะเห็นว่าอาร์กิวเมนต์ shift ควบคุมว่าแต่ละหน้าต่างจะเลื่อนไปมากแค่ไหน

เมื่อรวมสิ่งนี้เข้าด้วยกันคุณอาจเขียนฟังก์ชันนี้:

def make_window_dataset(ds, window_size=5, shift=1, stride=1):
  windows = ds.window(window_size, shift=shift, stride=stride)

  def sub_to_batch(sub):
    return sub.batch(window_size, drop_remainder=True)

  windows = windows.flat_map(sub_to_batch)
  return windows

ds = make_window_dataset(range_ds, window_size=10, shift = 5, stride=3)

for example in ds.take(10):
  print(example.numpy())
[ 0  3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34 37]
[15 18 21 24 27 30 33 36 39 42]
[20 23 26 29 32 35 38 41 44 47]
[25 28 31 34 37 40 43 46 49 52]
[30 33 36 39 42 45 48 51 54 57]
[35 38 41 44 47 50 53 56 59 62]
[40 43 46 49 52 55 58 61 64 67]
[45 48 51 54 57 60 63 66 69 72]

จากนั้นจึงง่ายต่อการแยกฉลากเหมือนเมื่อก่อน:

dense_labels_ds = ds.map(dense_1_step)

for inputs,labels in dense_labels_ds.take(3):
  print(inputs.numpy(), "=>", labels.numpy())
[ 0  3  6  9 12 15 18 21 24] => [ 3  6  9 12 15 18 21 24 27]
[ 5  8 11 14 17 20 23 26 29] => [ 8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34] => [13 16 19 22 25 28 31 34 37]

การสุ่มตัวอย่างใหม่

เมื่อทำงานกับชุดข้อมูลที่มีความไม่สมดุลของคลาสมากคุณอาจต้องการสุ่มตัวอย่างชุดข้อมูลอีกครั้ง tf.data มีสองวิธีในการดำเนินการนี้ ชุดข้อมูลการฉ้อโกงบัตรเครดิตเป็นตัวอย่างที่ดีของปัญหาประเภทนี้

zip_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',
    fname='creditcard.zip',
    extract=True)

csv_path = zip_path.replace('.zip', '.csv')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip
69156864/69155632 [==============================] - 3s 0us/step

creditcard_ds = tf.data.experimental.make_csv_dataset(
    csv_path, batch_size=1024, label_name="Class",
    # Set the column types: 30 floats and an int.
    column_defaults=[float()]*30+[int()])

ตอนนี้ตรวจสอบการกระจายของคลาสมันเบ้มาก:

def count(counts, batch):
  features, labels = batch
  class_1 = labels == 1
  class_1 = tf.cast(class_1, tf.int32)

  class_0 = labels == 0
  class_0 = tf.cast(class_0, tf.int32)

  counts['class_0'] += tf.reduce_sum(class_0)
  counts['class_1'] += tf.reduce_sum(class_1)

  return counts
counts = creditcard_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func = count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

fractions = counts/counts.sum()
print(fractions)
[0.9957 0.0043]

แนวทางทั่วไปในการฝึกกับชุดข้อมูลที่ไม่สมดุลคือการปรับสมดุล tf.data มีวิธีการบางอย่างที่เปิดใช้งานเวิร์กโฟลว์นี้:

การสุ่มตัวอย่างชุดข้อมูล

แนวทางหนึ่งในการสุ่มตัวอย่างชุดข้อมูลใหม่คือการใช้ sample_from_datasets สิ่งนี้ใช้ได้มากกว่าเมื่อคุณมี data.Dataset แยกต่างหากชุด data.Dataset สำหรับแต่ละคลาส

ที่นี่เพียงใช้ตัวกรองเพื่อสร้างจากข้อมูลการฉ้อโกงบัตรเครดิต:

negative_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==0)
    .repeat())
positive_ds = (
  creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==1)
    .repeat())
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7f79e43ce510> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==0)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function <lambda> at 0x7f79e43ce510> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==0)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7f79e43ceae8> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==1)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function <lambda> at 0x7f79e43ceae8> and will run it as-is.
Cause: could not parse the source code:

    .filter(lambda features, label: label==1)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

for features, label in positive_ds.batch(10).take(1):
  print(label.numpy())
[1 1 1 1 1 1 1 1 1 1]

ในการใช้ tf.data.experimental.sample_from_datasets ส่งชุดข้อมูลและน้ำหนักสำหรับแต่ละชุด:

balanced_ds = tf.data.experimental.sample_from_datasets(
    [negative_ds, positive_ds], [0.5, 0.5]).batch(10)

ตอนนี้ชุดข้อมูลจะสร้างตัวอย่างของแต่ละคลาสด้วยความน่าจะเป็น 50/50:

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
[0 1 1 1 1 0 0 1 0 1]
[0 0 0 0 1 1 1 0 0 1]
[1 0 1 0 1 1 0 0 1 0]
[0 0 0 0 1 0 1 1 0 0]
[0 1 1 1 0 1 1 1 1 1]
[0 1 0 1 1 0 1 1 0 1]
[0 0 1 1 1 0 1 1 1 1]
[1 1 0 0 1 1 0 1 1 0]
[0 1 0 0 1 0 0 0 0 0]
[0 1 1 1 0 1 0 1 1 0]

การสุ่มตัวอย่างการปฏิเสธ

ปัญหาอย่างหนึ่งของวิธี experimental.sample_from_datasets ข้างต้น tf.data.Dataset คือต้องการชุดข้อมูล tf.data.Dataset แยกต่างหากต่อคลาส การใช้ Dataset.filter ใช้งานได้ แต่ส่งผลให้ข้อมูลทั้งหมดถูกโหลดสองครั้ง

ฟังก์ชัน data.experimental.rejection_resample สามารถนำไปใช้กับชุดข้อมูลเพื่อปรับสมดุลใหม่ได้ในขณะที่โหลดเพียงครั้งเดียว องค์ประกอบจะถูกทิ้งจากชุดข้อมูลเพื่อให้เกิดความสมดุล

data.experimental.rejection_resample ใช้อาร์กิวเมนต์ class_func class_func นี้ถูกนำไปใช้กับแต่ละองค์ประกอบของชุดข้อมูลและใช้เพื่อพิจารณาว่าคลาสใดเป็นของตัวอย่างเพื่อจุดประสงค์ในการปรับสมดุล

องค์ประกอบของ creditcard_ds มีคู่ (features, label) อยู่แล้ว ดังนั้น class_func ก็ต้องส่งคืนป้ายกำกับเหล่านั้น:

def class_func(features, label):
  return label

resampler ยังต้องการการกระจายเป้าหมายและการประมาณการการกระจายเริ่มต้นเป็นทางเลือก:

resampler = tf.data.experimental.rejection_resample(
    class_func, target_dist=[0.5, 0.5], initial_dist=fractions)

resampler เกี่ยวข้องกับแต่ละตัวอย่างดังนั้นคุณต้อง unbatch ชุดข้อมูลก่อนที่จะใช้ resampler:

resample_ds = creditcard_ds.unbatch().apply(resampler).batch(10)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/experimental/ops/resampling.py:156: Print (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2018-08-20.
Instructions for updating:
Use tf.print instead of tf.Print. Note that tf.print returns a no-output operator that directly prints the output. Outside of defuns or eager mode, this operator will not be executed unless it is directly specified in session.run or used as a control dependency for other operators. This is only a concern in graph mode. Below is an example of how to ensure tf.print executes in graph mode:


resampler ส่งคืนสร้างคู่ (class, example) จากเอาต์พุตของ class_func ในกรณีนี้ example เป็นคู่ (feature, label) แล้วดังนั้นให้ใช้ map เพื่อวางสำเนาพิเศษของป้ายกำกับ:

balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)

ตอนนี้ชุดข้อมูลจะสร้างตัวอย่างของแต่ละคลาสด้วยความน่าจะเป็น 50/50:

for features, labels in balanced_ds.take(10):
  print(labels.numpy())
[1 0 1 1 0 0 0 0 0 0]
[1 0 1 0 1 1 0 1 0 1]
[0 1 1 0 0 1 0 0 0 0]
[1 0 1 0 0 1 0 1 0 0]
[1 1 1 1 0 0 0 0 0 1]
[1 0 1 1 1 0 1 0 1 1]
[0 0 1 1 1 1 0 0 1 1]
[0 0 1 0 0 1 1 1 0 1]
[1 0 0 0 1 0 0 0 0 0]
[0 0 0 0 1 0 0 0 1 1]

Iterator Checkpointing

Tensorflow รองรับ การตั้งจุดตรวจ เพื่อที่เมื่อกระบวนการฝึกของคุณเริ่มต้นใหม่จะสามารถกู้คืนด่านล่าสุดเพื่อกู้คืนความคืบหน้าส่วนใหญ่ได้ นอกจากการตรวจสอบตัวแปรแบบจำลองแล้วคุณยังสามารถตรวจสอบความคืบหน้าของตัววนซ้ำชุดข้อมูลได้อีกด้วย สิ่งนี้อาจมีประโยชน์หากคุณมีชุดข้อมูลขนาดใหญ่และไม่ต้องการเริ่มชุดข้อมูลตั้งแต่เริ่มต้นในการรีสตาร์ทแต่ละครั้ง อย่างไรก็ตามโปรดทราบว่าจุดตรวจตัววนซ้ำอาจมีขนาดใหญ่เนื่องจากการแปลงเช่นการ shuffle และการ prefetch ต้องใช้องค์ประกอบบัฟเฟอร์ภายในตัววนซ้ำ

ในการรวมตัววนซ้ำของคุณในจุดตรวจสอบให้ส่งตัววนซ้ำไปยังตัวสร้าง tf.train.Checkpoint

range_ds = tf.data.Dataset.range(20)

iterator = iter(range_ds)
ckpt = tf.train.Checkpoint(step=tf.Variable(0), iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, '/tmp/my_ckpt', max_to_keep=3)

print([next(iterator).numpy() for _ in range(5)])

save_path = manager.save()

print([next(iterator).numpy() for _ in range(5)])

ckpt.restore(manager.latest_checkpoint)

print([next(iterator).numpy() for _ in range(5)])
[0, 1, 2, 3, 4]
[5, 6, 7, 8, 9]
[5, 6, 7, 8, 9]

ใช้ tf.data กับ tf.keras

tf.keras API ช่วยลดความยุ่งยากในการสร้างและเรียกใช้โมเดลแมชชีนเลิร์นนิงในหลาย ๆ ด้าน .fit() และ. .evaluate() และ. .predict() APIs สนับสนุนชุดข้อมูลเป็นอินพุต นี่คือการตั้งค่าชุดข้อมูลและโมเดลอย่างรวดเร็ว:

train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)
fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

model = tf.keras.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), 
              metrics=['accuracy'])

การส่งผ่านชุดข้อมูลของคู่ (feature, label) เป็นสิ่งที่จำเป็นสำหรับ Model.fit และ Model.evaluate :

model.fit(fmnist_train_ds, epochs=2)
Epoch 1/2
WARNING:tensorflow:Layer flatten is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

1875/1875 [==============================] - 3s 2ms/step - loss: 0.5973 - accuracy: 0.7969
Epoch 2/2
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4629 - accuracy: 0.8407

<tensorflow.python.keras.callbacks.History at 0x7f79e43d8710>

หากคุณส่งผ่านชุดข้อมูลที่ไม่มีที่สิ้นสุดตัวอย่างเช่นโดยการเรียก Dataset.repeat() คุณก็ต้องส่งอาร์กิวเมนต์ steps_per_epoch :

model.fit(fmnist_train_ds.repeat(), epochs=2, steps_per_epoch=20)
Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4800 - accuracy: 0.8594
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4529 - accuracy: 0.8391

<tensorflow.python.keras.callbacks.History at 0x7f79e43d84a8>

สำหรับการประเมินคุณสามารถผ่านขั้นตอนการประเมินได้หลายขั้นตอน:

loss, accuracy = model.evaluate(fmnist_train_ds)
print("Loss :", loss)
print("Accuracy :", accuracy)
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4386 - accuracy: 0.8508
Loss : 0.4385797381401062
Accuracy : 0.8507500290870667

สำหรับชุดข้อมูลแบบยาวให้กำหนดจำนวนขั้นตอนในการประเมิน:

loss, accuracy = model.evaluate(fmnist_train_ds.repeat(), steps=10)
print("Loss :", loss)
print("Accuracy :", accuracy)
10/10 [==============================] - 0s 2ms/step - loss: 0.3737 - accuracy: 0.8656
Loss : 0.37370163202285767
Accuracy : 0.8656250238418579

ไม่จำเป็นต้องใช้ป้ายกำกับเมื่อเรียก Model.predict

predict_ds = tf.data.Dataset.from_tensor_slices(images).batch(32)
result = model.predict(predict_ds, steps = 10)
print(result.shape)
(320, 10)

แต่ป้ายกำกับจะถูกละเว้นหากคุณส่งผ่านชุดข้อมูลที่มี:

result = model.predict(fmnist_train_ds, steps = 10)
print(result.shape)
(320, 10)