
Visualize images (and labels) from an image classification dataset.

    tfds.show_examples(ds_info, ds, rows=3, cols=3, plot_scale=3.0, image_key=None)

Only works with datasets that have exactly one image feature and, optionally, one label feature (both inferred from ds_info). Note that the dataset should be unbatched. Requires matplotlib to be installed.
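The feature requirement above can be illustrated with a small, self-contained sketch. It assumes a hypothetical feature spec given as a plain dict mapping feature names to kind strings ("image", "class_label", ...); the helper name infer_keys is also hypothetical. This is not the TFDS implementation, which inspects ds_info.features directly; it only shows the "exactly one image, at most one label" rule.

```python
from typing import Dict, Optional, Tuple


def infer_keys(features: Dict[str, str]) -> Tuple[str, Optional[str]]:
    """Return (image_key, label_key) from a hypothetical name -> kind spec.

    Raises ValueError when the spec does not have exactly one image feature
    or has more than one label feature.
    """
    image_keys = [name for name, kind in features.items() if kind == "image"]
    label_keys = [name for name, kind in features.items() if kind == "class_label"]
    if len(image_keys) != 1:
        raise ValueError(f"Expected exactly 1 image feature, found {image_keys}")
    if len(label_keys) > 1:
        raise ValueError(f"Expected at most 1 label feature, found {label_keys}")
    # The label is optional: return None when the dataset has no label feature.
    label_key = label_keys[0] if label_keys else None
    return image_keys[0], label_key


print(infer_keys({"image": "image", "label": "class_label"}))  # ('image', 'label')
print(infer_keys({"image": "image", "id": "text"}))  # ('image', None)
```

A dataset with two image features (or two label features) is rejected, which is the situation where passing image_key explicitly becomes necessary.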

This function is intended for interactive use (Colab, Jupyter). It displays and returns a plot of (rows * cols) images from the given dataset.


Usage:

    import tensorflow_datasets as tfds

    ds, ds_info = tfds.load('cifar10', split='train', with_info=True)
    fig = tfds.show_examples(ds_info, ds)


Args:

  • ds_info: The dataset info object from which to extract the label and feature info. Available either through tfds.load('mnist', with_info=True) or tfds.builder('mnist').info.
  • ds: The dataset to visualize. Examples should not be batched. Examples are consumed in order until (rows * cols) examples have been read or the dataset is exhausted.
  • rows: int, number of rows of the display grid.
  • cols: int, number of columns of the display grid.
  • plot_scale: float, controls the plot size of the images. Values around 3 give a good plot; much higher or lower values may cause the labels to overlap.
  • image_key: string, name of the feature that contains the image. If not set, the function tries to auto-detect it from ds_info.
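How rows, cols and plot_scale interact can be sketched without TensorFlow or matplotlib. The sketch below consumes examples in order until rows * cols have been read or the input runs out, and places them row by row. The sizing rule (each cell about plot_scale units on a side, so the figure is plot_scale * cols wide and plot_scale * rows tall) is an assumption made for illustration, and plan_grid is a hypothetical helper, not part of TFDS.

```python
import itertools
from typing import Iterable, List, Tuple, TypeVar

T = TypeVar("T")


def plan_grid(examples: Iterable[T], rows: int = 3, cols: int = 3,
              plot_scale: float = 3.0
              ) -> Tuple[Tuple[float, float], List[Tuple[int, int, T]]]:
    """Return (figsize, cells), where each cell is (row, col, example)."""
    # Assumed sizing rule: width grows with cols, height grows with rows.
    figsize = (plot_scale * cols, plot_scale * rows)
    cells = []
    # islice stops after rows * cols items, or earlier if the input is exhausted.
    for i, example in enumerate(itertools.islice(examples, rows * cols)):
        row, col = divmod(i, cols)  # fill the grid row by row
        cells.append((row, col, example))
    return figsize, cells


figsize, cells = plan_grid(range(100), rows=2, cols=3)
print(figsize)    # (9.0, 6.0)
print(cells[-1])  # (1, 2, 5)
```

With a short input such as range(4) and the default 3 x 3 grid, only four cells are filled, matching the "until (rows * cols) are read or the dataset is exhausted" behavior described for ds.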


Returns:

  • fig: The matplotlib.figure.Figure object containing the plot.