
tfp.experimental.nn.util.tune_dataset


Sets generally recommended parameters for a tf.data.Dataset.

tfp.experimental.nn.util.tune_dataset(
    dataset, batch_size=None, shuffle_size=None, preprocess_fn=None, repeat_count=-1
)

Args:

  • dataset: tf.data.Dataset-like instance to be tuned according to this function's arguments.
  • batch_size: Python int representing the number of elements in each minibatch.
  • shuffle_size: Python int representing the number of elements to shuffle (at a time).
  • preprocess_fn: Python callable applied to each item in dataset.
  • repeat_count: Python int representing the number of times the dataset should be repeated. The default behavior (repeat_count = -1) is to repeat the dataset indefinitely. If repeat_count is None, repeat is turned off; note that this deviates from tf.data.Dataset.repeat, which interprets None as "repeat indefinitely". Default value: -1 (i.e., repeat indefinitely).

Returns:

  • tuned_dataset: tf.data.Dataset instance tuned according to this function's arguments.
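
Conceptually, the returned dataset is roughly what you would get by chaining the standard tf.data methods by hand. The sketch below only illustrates the documented argument semantics; it is not the actual implementation, and in particular the parallel map and trailing prefetch are assumptions.

def tune_dataset_sketch(dataset, batch_size=None, shuffle_size=None,
                        preprocess_fn=None, repeat_count=-1):
  # Unlike tf.data.Dataset.repeat, repeat_count=None here means "no repeat".
  if repeat_count is not None:
    dataset = dataset.repeat(repeat_count)  # -1 => repeat indefinitely.
  if shuffle_size is not None:
    dataset = dataset.shuffle(shuffle_size)
  if preprocess_fn is not None:
    # Assumed: parallel map; the real implementation may differ.
    dataset = dataset.map(
        preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  if batch_size is not None:
    dataset = dataset.batch(batch_size)
  # Assumed: prefetch to overlap preprocessing with consumption.
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)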

Example

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

tune_dataset = tfp.experimental.nn.util.tune_dataset

[train_dataset, eval_dataset], datasets_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    with_info=True,
    as_supervised=True,
    shuffle_files=True)

def _preprocess(image, label):
  image = tf.cast(image, dtype=tf.int32)
  u = tf.random.uniform(shape=tf.shape(image), maxval=256, dtype=image.dtype)
  image = tf.cast(u < image, dtype=tf.float32)   # Randomly binarize.
  return image, label

# TODO(b/144500779): Can't use `experimental_compile=True`.
@tf.function(autograph=False)
def one_step(iterator):
  x, _ = next(iterator)  # Only the image batch is used.
  return tf.reduce_mean(x)

ds = tune_dataset(
    train_dataset,
    batch_size=32,
    shuffle_size=int(datasets_info.splits['train'].num_examples / 7),
    preprocess_fn=_preprocess)
it = iter(ds)
for _ in range(3):
  one_step(it)  # Build graph / burn-in.
%time one_step(it)
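
Outside an IPython shell the %time magic is unavailable; a plain-Python timing of the same step, using only the standard library, looks like:

import time

start = time.perf_counter()
one_step(it)
print('one_step: {:.4f}s'.format(time.perf_counter() - start))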