此页面由 Cloud Translation API 翻译。
Switch to English

使用TF Profiler分析tf.data性能

总览

本指南假定您熟悉TensorFlow Profilertf.data 。它旨在逐步提供示例指导,以帮助用户诊断和修复输入管道性能问题。

首先,收集您的TensorFlow作业的配置文件。有关如何执行此操作的说明,可用于CPU / GPUCloud TPU

TensorFlow Trace Viewer

下面详细介绍的分析工作流程着重于Profiler中的跟踪查看器工具。该工具显示时间轴,该时间轴显示TensorFlow程序执行的操作的持续时间,并允许您确定执行时间最长的操作。有关跟踪查看器的更多信息,请查看TF Profiler指南的此部分 。通常, tf.data事件将出现在主机CPU时间轴上。

分析工作流程

请遵循以下工作流程。如果您有反馈来帮助我们改进它,请创建一个标签为“ comp:data” 的github问题

1.您的tf.data管道生成数据的速度足够快吗?

首先确定输入管道是否是TensorFlow程序的瓶颈。

这样做,在跟踪查看器中查找IteratorGetNext::DoCompute ops。通常,您希望在步骤开始时就能看到这些内容。这些切片表示输入管道在被请求时产生一批元素所花费的时间。如果您正在使用keras或在tf.function遍历数据集,则应在tf_data_iterator_get_next线程中找到tf_data_iterator_get_next

请注意,如果您使用的是分销策略 ,你可能会看到IteratorGetNextAsOptional::DoCompute事件,而不是IteratorGetNext::DoCompute (如TF 2.3)。

image

如果呼叫迅速返回(<= 50 us),则意味着您的数据在需要时可用。输入管道不是您的瓶颈。有关更多常规性能分析提示,请参见Profiler指南

image

如果呼叫返回缓慢,则 tf.data无法满足使用者的请求。继续下一节。

2.您是否正在预取数据?

输入管道性能的最佳做法是在tf.data管道的末尾插入tf.data.Dataset.prefetch转换。此转换将输入管道的预处理计算与模型计算的下一步重叠,是训练模型时实现最佳输入管道性能所必需的。如果要预取数据,则应该在与IteratorGetNext::DoCompute op相同的线程上看到Iterator::Prefetch切片。

image

如果您在管道的末尾没有prefetch ,则应添加一个。有关tf.data性能建议的更多信息,请参阅《 tf.data性能指南》

如果您已经在预取数据 ,而输入管道仍然是您的瓶颈,请继续下一节以进一步分析性能。

3.您是否达到很高的CPU利用率?

tf.data通过尝试最大程度地利用可用资源来实现高吞吐量。通常,即使在GPU或TPU等加速器上运行模型时, tf.data管道也在CPU上运行。您可以使用sarhtop之类的工具来检查您的利用率,或者如果您在GCP上运行,则可以在云监视控制台中检查其利用率。

如果您的利用率低,则表明您的输入管道可能没有充分利用主机CPU。您应查阅tf.data性能指南以获取最佳做法。如果您应用了最佳实践,并且利用率和吞吐量仍然很低,请继续下面的瓶颈分析

如果您的利用率接近资源极限 ,为了进一步提高性能,则需要提高输入管道的效率(例如,避免不必要的计算)或卸载计算。

通过避免tf.data不必要的计算,可以提高输入管道的效率。一种方法是,如果您的数据适合内存,则在需要大量计算的工作之后插入tf.data.Dataset.cache转换。这以增加内存使用为代价减少了计算。此外,禁用tf.data中的运算tf.data内部并行tf.data有可能使效率提高> 10%,并且可以通过在输入管道上设置以下选项来完成:

 dataset = ...
options = tf.data.Options()
options.experimental_threading.max_intra_op_parallelism = 1
dataset = dataset.with_options(options)
 

4.瓶颈分析

下一节将tf.data介绍如何在跟踪查看器中读取tf.data事件,以了解瓶颈在哪里以及可能的缓解策略。

了解事件探查器中的tf.data事件

事件探查器中的每个tf.data事件都具有名称Iterator::<Dataset> ,其中<Dataset>是数据集源或转换的名称。每个事件还具有长名称Iterator::<Dataset_1>::...::<Dataset_n> ,您可以通过单击tf.data事件来看到它。在长名称中, <Dataset_n>与(短)名称中的<Dataset> <Dataset_n>匹配,长名称中的其他数据集表示下游转换。

image

例如,上面的屏幕截图是从以下代码生成的:

 dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x)
dataset = dataset.repeat(2)
dataset = dataset.batch(5)
 

在这里, Iterator::Map事件的长名称为Iterator::BatchV2::FiniteRepeat::Map 。请注意,数据集名称可能与python API略有不同(例如,使用FiniteRepeat而不是Repeat),但应足够直观以进行解析。

同步和异步转换

对于同步tf.data转换(例如BatchMap ),您将在同一线程上看到来自上游转换的事件。在上面的示例中,由于使用的所有转换都是同步的,因此所有事件都出现在同一线程上。

对于异步转换(例如PrefetchParallelMapParallelInterleaveMapAndBatch ),来自上游转换的事件将在不同的线程上。在这种情况下,“长名称”可以帮助您确定事件对应于管道中的哪个转换。

image

例如,上面的屏幕截图是从以下代码生成的:

 dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x)
dataset = dataset.repeat(2)
dataset = dataset.batch(5)
dataset = dataset.prefetch(1)
 

在这里, Iterator::Prefetch事件在tf_data_iterator_get_next线程上。由于Prefetch是异步的,因此其输入事件( BatchV2 )将位于不同的线程上,可以通过搜索长名称Iterator::Prefetch::BatchV2 。在这种情况下,它们位于tf_data_iterator_resource线程上。从其长名称中,您可以推断出BatchV2Prefetch上游。此外, BatchV2事件的parent_id将与Prefetch事件的ID匹配。

识别瓶颈

通常,要确定输入管道中的瓶颈,请将输入管道从最外面的转换一直到源代码。从管道中的最终转换开始,递归到上游转换,直到找到缓慢的转换或到达源数据集,例如TFRecord 。在上面的示例中,您将从Prefetch开始,然后向上游BatchV2FiniteRepeatMap ,最后是Range

通常,慢速转换对应于事件长而输入事件短的转换。下面是一些示例。

请注意,大多数主机输入管道中的最终(最外部)转换是Iterator::Model事件。模型转换是由tf.data运行时自动引入的,用于检测和自动调整输入管道的性能。

如果您的工作使用分配策略 ,则跟踪查看器将包含与设备输入管道相对应的其他事件。设备管道的最外层转换(位于IteratorGetNextOp::DoComputeIteratorGetNextAsOptionalOp::DoCompute )将是带有上游Iterator::Generator事件的Iterator::Prefetch事件。您可以通过搜索Iterator::Model事件找到相应的主机管道。

例子1

image

上面的屏幕截图是从以下输入管道生成的:

 dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(parse_record)
dataset = dataset.batch(32)
dataset = dataset.repeat()
 

在屏幕截图中,观察到(1) Iterator::Map事件很长,但是(2)其输入事件( Iterator::FlatMap )快速返回。这表明顺序Map转换是瓶颈。

请注意,在屏幕截图中, InstantiatedCapturedFunction::Run事件对应于执行map函数所花费的时间。

例子2

image

上面的屏幕截图是从以下输入管道生成的:

 dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.map(parse_record, num_parallel_calls=2)
dataset = dataset.batch(32)
dataset = dataset.repeat()
 

此示例与上面的示例相似,但是使用ParallelMap而不是Map。我们在这里注意到,(1) Iterator::ParallelMap事件很长,但是(2)它的输入事件Iterator::FlatMap (由于ParallelMap是异步的,它们在不同的线程上)很短。这表明ParallelMap转换是瓶颈。

解决瓶颈

源数据集

如果已将数据集源识别为瓶颈,例如从TFRecord文件中读取数据,则可以通过并行化数据提取来提高性能。为此,请确保将数据分片到多个文件中,并使用tf.data.Dataset.interleave ,并将num_parallel_calls参数设置为tf.data.experimental.AUTOTUNE 。如果确定性对于您的程序不重要,则可以通过从TF 2.2开始在tf.data.Dataset.interleave上设置deterministic=False标志来进一步提高性能。例如,如果您正在读取TFRecords,则可以执行以下操作:

 dataset = tf.data.Dataset.from_tensor_slices(filenames)
dataset = dataset.interleave(tf.data.TFRecordDataset,
  num_parallel_calls=tf.data.experimental.AUTOTUNE,
  deterministic=False)
 

请注意,分片文件应足够大,以分摊打开文件的开销。有关并行数据提取的更多详细信息,请参阅tf.data性能指南的此部分

转换数据集

如果您将tf.data中间转换识别为瓶颈,则可以通过并行化转换或缓存计算(如果数据适合内存)来解决该问题。一些转换(例如Map具有并行的对应项; 《 tf.data性能指南》演示了如何并行处理这些问题。其他转换(例如FilterUnbatchBatch本质上是顺序的。您可以通过引入“外部并行性”来并行化它们。例如,假设您的输入管道最初看起来如下所示,并且将Batch作为瓶颈:

 filenames = tf.data.Dataset.list_files(file_path, shuffle=is_training)
dataset = filenames_to_dataset(filenames)
dataset = dataset.batch(batch_size)
 

您可以通过在分片输入上运行输入管道的多个副本并组合结果来引入“外部并行性”:

 filenames = tf.data.Dataset.list_files(file_path, shuffle=is_training)

def make_dataset(shard_index):
  filenames = filenames.shard(NUM_SHARDS, shard_index)
  dataset = filenames_to_dataset(filenames)
  Return dataset.batch(batch_size)

indices = tf.data.Dataset.range(NUM_SHARDS)
dataset = indices.interleave(make_dataset,
                             num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
 

额外资源