Returns the subsplit of the data for the process.
```python
tfds.split_for_jax_process(
    split: str,
    *,
    process_index: tfds.typing.Dim = None,
    process_count: tfds.typing.Dim = None,
    drop_remainder: bool = False
) -> tfds.typing.SplitArg
```
In a distributed setting, all processes/hosts should get non-overlapping, equally sized slices of the entire data. This function takes a split as input and extracts the slice for the current process index.
Usage:
tfds.load(..., split=tfds.split_for_jax_process('train'))
This function is an alias for:
tfds.even_splits(split, n=jax.process_count())[jax.process_index()]
By default, if the examples cannot be evenly distributed across processes, the extras are spread as evenly as possible; set `drop_remainder=True` to drop the extra examples instead.
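The even-distribution behaviour can be illustrated with a small pure-Python sketch (`even_split_sizes` is a hypothetical helper mirroring the behaviour described above, not part of the `tfds` API):

```python
def even_split_sizes(num_examples, n, drop_remainder=False):
    """Hypothetical helper: sizes of n subsplits over num_examples examples,
    mirroring the even-distribution behaviour described above."""
    base, remainder = divmod(num_examples, n)
    if drop_remainder:
        # Every subsplit gets the same size; leftover examples are dropped.
        return [base] * n
    # Leftover examples go to the first subsplits, one extra example each.
    return [base + (1 if i < remainder else 0) for i in range(n)]

print(even_split_sizes(11, 3))                       # -> [4, 4, 3]
print(even_split_sizes(11, 3, drop_remainder=True))  # -> [3, 3, 3]
```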
Args | |
---|---|
`split` | Split to distribute across hosts (e.g. `train[75%:]`, `train[:800]+validation[:100]`).
`process_index` | Process index in `[0, process_count)`. Defaults to `jax.process_index()`.
`process_count` | Number of processes. Defaults to `jax.process_count()`.
`drop_remainder` | Drop examples if the number of examples in the dataset is not evenly divisible by `process_count`. If `False`, examples are distributed as evenly as possible, starting from the first subsplit. For example, with 11 examples and `process_count=3`, the subsplits contain `[4, 4, 3]` examples respectively.
Returns | |
---|---|
`subsplit` | The sub-split of the given `split` for the current `process_index`.
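The subsplit returned for each process corresponds to a contiguous slice of example indices. A pure-Python sketch of those bounds (`subsplit_bounds` is a hypothetical helper, assuming the even distribution described above with extras assigned to the lowest-index processes):

```python
def subsplit_bounds(num_examples, process_count, process_index):
    """Hypothetical helper: the [start, end) example range for one process,
    assuming extra examples go to the lowest-index processes."""
    base, remainder = divmod(num_examples, process_count)
    start = process_index * base + min(process_index, remainder)
    end = start + base + (1 if process_index < remainder else 0)
    return start, end

# 11 examples over 3 processes -> slices [0, 4), [4, 8), [8, 11)
for idx in range(3):
    print(subsplit_bounds(11, 3, idx))
```

Together the slices cover every example exactly once, which is why each process can iterate over its subsplit independently without duplication.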