Sets generally recommended parameters for a `tf.data.Dataset`.
tfp.experimental.nn.util.tune_dataset(
    dataset,
    batch_size=None,
    shuffle_size=None,
    preprocess_fn=None,
    repeat_count=-1
)
Args

  dataset: `tf.data.Dataset`-like instance to be tuned according to this
    function's arguments.
  batch_size: Python `int` representing the number of elements in each
    minibatch.
  shuffle_size: Python `int` representing the number of elements to shuffle
    (at a time).
  preprocess_fn: Python `callable` applied to each item in `dataset`.
  repeat_count: Python `int` representing the number of times the dataset
    should be repeated. The default behavior (`repeat_count = -1`) is for
    the dataset to be repeated indefinitely. If `repeat_count` is `None`,
    repeat is "off"; note that this is a deviation from
    `tf.data.Dataset.repeat`, which interprets `None` as "repeat
    indefinitely".
    Default value: `-1` (i.e., repeat indefinitely).

Returns

  tuned_dataset: `tf.data.Dataset` instance tuned according to this
    function's arguments.
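Conceptually, tuning a dataset amounts to chaining standard `tf.data`
transformations. The sketch below is illustrative only: the exact ordering
and options used inside the library are an assumption, not taken from the
source. It shows how `map`, `shuffle`, `repeat`, `batch`, and `prefetch`
could combine to honor the arguments above, including the
`repeat_count=None` deviation:

import tensorflow as tf

def tune_dataset_sketch(dataset, batch_size=None, shuffle_size=None,
                        preprocess_fn=None, repeat_count=-1):
  # Hypothetical re-implementation for illustration; not the library source.
  if preprocess_fn is not None:
    # Apply the per-element transformation with parallel calls.
    dataset = dataset.map(
        preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  if shuffle_size is not None:
    dataset = dataset.shuffle(shuffle_size)
  if repeat_count is not None:
    # -1 repeats indefinitely; None (caught above) disables repeat,
    # unlike `tf.data.Dataset.repeat`, which treats None as indefinite.
    dataset = dataset.repeat(repeat_count)
  if batch_size is not None:
    dataset = dataset.batch(batch_size)
  # Overlap input preparation with downstream computation.
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)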
Example
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

tune_dataset = tfp.experimental.nn.util.tune_dataset

[train_dataset, eval_dataset], datasets_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    with_info=True,
    as_supervised=True,
    shuffle_files=True)

def _preprocess(image, label):
  image = tf.cast(image, dtype=tf.int32)
  u = tf.random.uniform(shape=tf.shape(image), maxval=256, dtype=image.dtype)
  image = tf.cast(u < image, dtype=tf.float32)  # Randomly binarize.
  return image, label

@tf.function(autograph=False)
def one_step(iterator):
  x, y = next(iterator)
  return tf.reduce_mean(x)

ds = tune_dataset(
    train_dataset,
    batch_size=32,
    shuffle_size=int(datasets_info.splits['train'].num_examples / 7),
    preprocess_fn=_preprocess)
it = iter(ds)
for _ in range(3):
  one_step(it)  # Build graph / burn-in.
%time one_step(it)  # IPython magic: time one post-burn-in step.
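For a single finite pass (e.g. over the eval split), the `repeat_count=None`
deviation means plain iteration terminates after one epoch. A minimal usage
sketch, reusing the names defined above:

eval_ds = tune_dataset(
    eval_dataset,
    batch_size=32,
    preprocess_fn=_preprocess,
    repeat_count=None)  # Repeat "off": exactly one pass.
num_batches = 0
for x, y in eval_ds:  # Terminates after one epoch of the test split.
  num_batches += 1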