tff.analytics.data_processing.to_stacked_tensor

Encodes the `tf.data.Dataset` as stacked tensors.

This is effectively the inverse of `tf.data.Dataset.from_tensor_slices()`. All elements of the input dataset are concatenated into a tensor structure that matches the input `ds.element_spec`. Each output tensor has the same shape as the corresponding element tensor, plus one additional leading dimension along which the elements are stacked. For example, if the dataset contains 5 elements of shape [3, 2], the returned tensor has shape [5, 3, 2]. Each element in the dataset may be a single tensor or a structure of tensors.
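To make the stacking concrete, the sketch below reproduces the documented behavior with plain TensorFlow in eager mode. `stack_elements` is a hypothetical helper written for illustration only. It is not the TFF implementation, and it assumes a finite, non-empty dataset whose elements fit in memory.

```python
import tensorflow as tf


def stack_elements(ds: tf.data.Dataset):
  """Hypothetical helper: stacks every element of `ds` along a new axis 0.

  Eager-only sketch; assumes `ds` is finite and non-empty.
  """
  elements = list(ds)  # Each item has the structure of ds.element_spec.
  # Stack the matching leaves of all elements, preserving the nested structure.
  return tf.nest.map_structure(lambda *xs: tf.stack(xs, axis=0), *elements)


# 5 elements, each a single tensor of shape [3, 2].
original = tf.reshape(tf.range(30), [5, 3, 2])
ds = tf.data.Dataset.from_tensor_slices(original)
stacked = stack_elements(ds)
print(stacked.shape)  # (5, 3, 2)
print(bool(tf.reduce_all(tf.equal(stacked, original))))  # True

# Elements may also be structures of tensors; the output keeps that structure.
ds_dict = tf.data.Dataset.from_tensor_slices(
    {"x": tf.ones([4, 2]), "y": tf.range(4)})
stacked_dict = stack_elements(ds_dict)
print(stacked_dict["x"].shape, stacked_dict["y"].shape)  # (4, 2) (4,)
```

The round trip through `from_tensor_slices` and back recovers the original tensor, which is what "inverse" means here.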

Dataset elements must have fully-defined shapes; a dataset with any partially-defined element shape raises an error. If passing in a batched dataset, use `drop_remainder=True` so that the batch dimension, and therefore the element shape, is fully defined.
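The effect of `drop_remainder` can be checked by inspecting `element_spec` before stacking. This plain TensorFlow snippet makes no TFF calls. `all_shapes_defined` is a hypothetical helper that walks a possibly nested spec, since elements may be structures of tensors.

```python
import tensorflow as tf


def all_shapes_defined(spec) -> bool:
  """Hypothetical helper: True if every TensorSpec in `spec` has a full shape."""
  return all(s.shape.is_fully_defined() for s in tf.nest.flatten(spec))


# Ten elements, each of shape [3].
ds = tf.data.Dataset.range(10).map(lambda x: tf.fill([3], x))

# Without drop_remainder the batch dimension is unknown (the last batch may be short).
partial = ds.batch(4)
print(partial.element_spec.shape)  # (None, 3)
print(all_shapes_defined(partial.element_spec))  # False

# With drop_remainder=True every batch has exactly 4 elements.
full = ds.batch(4, drop_remainder=True)
print(full.element_spec.shape)  # (4, 3)
print(all_shapes_defined(full.element_spec))  # True
```

Note that dropping the remainder discards the last 2 of the 10 elements, so only 8 elements would be stacked.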

Args:
  ds: The input `tf.data.Dataset` to stack.

Returns:
  A structure of tensors encoding the input dataset.

Raises:
  ValueError: If any dataset element shape is not fully-defined.