tfp.experimental.nn.util.tune_dataset

Sets generally recommended parameters for a tf.data.Dataset.

dataset tf.data.Dataset-like instance to be tuned according to this function's arguments.
batch_size Python int representing the number of elements in each minibatch.
shuffle_size Python int representing the number of elements to shuffle (at a time).
preprocess_fn Python callable applied to each item in dataset.
repeat_count Python int representing the number of times the dataset should be repeated. The default (repeat_count = -1) repeats the dataset indefinitely. If repeat_count is None, repeating is turned off; note that this deviates from tf.data.Dataset.repeat, which interprets None as "repeat indefinitely". Default value: -1 (i.e., repeat indefinitely).

tuned_dataset tf.data.Dataset instance tuned according to this function's arguments.
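
The repeat_count rule is the one place where this function departs from tf.data.Dataset.repeat, so a small illustration may help. The sketch below is a pure-Python analogue using itertools, not the library's implementation; the helper name repeat_items is hypothetical, and it assumes items is a re-iterable sequence such as a list.

```python
import itertools


def repeat_items(items, repeat_count=-1):
    """Hypothetical pure-Python analogue of the documented `repeat_count` rule.

    Assumes `items` is a re-iterable sequence (e.g. a list).
    """
    if repeat_count is None:
        # Documented: None turns repeating off, so make a single pass.
        return iter(items)
    if repeat_count == -1:
        # Documented default: repeat indefinitely.
        return itertools.cycle(items)
    # Otherwise replay the whole sequence `repeat_count` times.
    return itertools.chain.from_iterable(itertools.repeat(items, repeat_count))


data = [1, 2, 3]
single_pass = list(repeat_items(data, repeat_count=None))
two_passes = list(repeat_items(data, repeat_count=2))
# The default never terminates, so only a finite prefix is taken.
first_seven = list(itertools.islice(repeat_items(data), 7))
```

In practice this means that with the default repeat_count the tuned dataset has no natural end, so a training loop needs its own stopping condition (for example, a fixed number of steps).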

Example

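import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

tune_dataset = tfp.experimental.nn.util.tune_dataset

[train_dataset, eval_dataset], datasets_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    with_info=True,
    as_supervised=True,
    shuffle_files=True)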

def _preprocess(image, label):
  image = tf.cast(image, dtype=tf.int32)
  u = tf.random.uniform(shape=tf.shape(image), maxval=256, dtype=image.dtype)
  image = tf.cast(u < image, dtype=tf.float32)   # Randomly binarize.
  return image, label


@tf.function(autograph=False)
def one_step(iterator):
  x, _ = next(iterator)  # The label is not needed here.
  return tf.reduce_mean(x)

ds = tune_dataset(
    train_dataset,
    batch_size=32,
    shuffle_size=int(datasets_info.splits['train'].num_examples / 7),
    preprocess_fn=_preprocess)
it = iter(ds)
for _ in range(3):  # Build graph / burn-in.
  one_step(it)
%time one_step(it)  # IPython/Jupyter magic: time a single step.
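
The %time line only works inside IPython or Jupyter. For a plain Python script, one possible substitute is sketched below using time.perf_counter; the helper time_call and the stand-in fake_step are hypothetical names introduced for illustration, and the helper reports wall-clock time only.

```python
import time


def time_call(fn, *args, warmup=3):
    """Hypothetical helper: run `fn` a few times untimed, then time one call."""
    for _ in range(warmup):
        fn(*args)  # Untimed warm-up calls (e.g. graph building, cache filling).
    start = time.perf_counter()
    result = fn(*args)
    elapsed = time.perf_counter() - start
    return result, elapsed


# Stand-in workload so the sketch runs without TensorFlow installed.
calls = []


def fake_step(x):
    calls.append(x)
    return x * 2


result, elapsed = time_call(fake_step, 21)
print(f"result={result}, elapsed={elapsed:.6f}s")
```

With the example above, time_call(one_step, it) would take the place of both the burn-in loop and the %time line.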