Skip to content

Dataset

Building towards: MNIST dataset

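PyTorch (map-style) datasets implement the `__getitem__` and `__len__` methods: `__getitem__` returns the sample at a given index and `__len__` returns the number of samples.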


Minimal 'non-null' dataset

Below is an example of perhaps the simplest 'non-null' dataset, which just returns the index of the sample requested. Note that the dataset on its own returns plain Python integers, not `torch.Tensor`s.

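The `DataLoader` collates samples into batches. Its default collate function handles tensors, NumPy arrays, Python numbers and strings, and nested lists/tuples/dicts of these (here, the integers become tensors); other sample types require a custom `collate_fn`.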

import torch
from torch.utils.data import Dataset, DataLoader

# Basic dataset that just returns the idx called
class Ids(Dataset[int]):
    """Toy map-style dataset whose ``i``-th sample is the integer ``i``.

    Args:
        total: Number of samples; valid indices are ``0 .. total - 1``
            (negative indices count from the end, as for Python sequences).
    """

    def __init__(self, total: int = 10):
        self.total = total

    def __len__(self) -> int:
        return self.total

    def __getitem__(self, idx: int) -> int:
        # Negative indices count from the end, like list indexing.
        if idx < 0:
            idx += self.total
        # torch's Dataset defines no `__iter__`, so `iter(dataset)` /
        # `list(dataset)` fall back to calling __getitem__(0), (1), ... and
        # only stop at an IndexError. Without this check they never terminate.
        if not 0 <= idx < self.total:
            raise IndexError(f"index out of range for dataset of length {self.total}")
        return idx

dataset = Ids()

print("\n\n--- Samples ---\n")
# Index explicitly over range(len(dataset)). `list(dataset)` would fall back to
# Python's legacy __getitem__ iteration, which only stops at an IndexError and
# therefore never ends for a dataset that accepts any index.
samples = [dataset[i] for i in range(len(dataset))]
print(samples)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

print("\n\n--- Batches ---\n")
# The DataLoader samples indices 0 .. len(dataset) - 1 and the default collate
# function turns each list of ints into a tensor.
batches = list(DataLoader(dataset, shuffle=False, batch_size=3))
print(batches)
# [tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9])]

MNIST

What does the dataset from the MNIST example look like?

The relevant lines:

# Preprocessing: PIL image -> float tensor, then normalize with mean 0.1307 and
# std 0.3081 (presumably the MNIST training-set pixel statistics; they are not
# derived here).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Only the training set triggers a download; the test set (download=False)
# relies on the raw files already being present under ../data.
dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
dataset2 = datasets.MNIST("../data", train=False, download=False, transform=transform)

# `train_kwargs` / `test_kwargs` are defined elsewhere in the MNIST example.
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

VisionDataset

Since datasets.MNIST inherits from VisionDataset (class MNIST(VisionDataset)), let us check VisionDataset first:

It is a base class that handles image and target transformations. Other datasets built on it include FashionMNIST, etc. The full list of `torchvision.datasets` can be found in torchvision's `.venv/lib/python3.11/site-packages/torchvision/datasets/__init__.py`:

Contents of `.venv/lib/python3.11/site-packages/torchvision/datasets/__init__.py`
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import (
CarlaStereo,
CREStereo,
ETH3DStereo,
FallingThingsStereo,
InStereo2k,
Kitti2012Stereo,
Kitti2015Stereo,
Middlebury2014Stereo,
SceneFlowStereo,
SintelStereo,
)
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
from .cityscapes import Cityscapes
from .clevr import CLEVRClassification
from .coco import CocoCaptions, CocoDetection
from .country211 import Country211
from .dtd import DTD
from .eurosat import EuroSAT
from .fakedata import FakeData
from .fer2013 import FER2013
from .fgvc_aircraft import FGVCAircraft
from .flickr import Flickr30k, Flickr8k
from .flowers102 import Flowers102
from .folder import DatasetFolder, ImageFolder
from .food101 import Food101
from .gtsrb import GTSRB
from .hmdb51 import HMDB51
from .imagenet import ImageNet
from .imagenette import Imagenette
from .inaturalist import INaturalist
from .kinetics import Kinetics
from .kitti import Kitti
from .lfw import LFWPairs, LFWPeople
from .lsun import LSUN, LSUNClass
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
from .moving_mnist import MovingMNIST
from .omniglot import Omniglot
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
from .phototour import PhotoTour
from .places365 import Places365
from .rendered_sst2 import RenderedSST2
from .sbd import SBDataset
from .sbu import SBU
from .semeion import SEMEION
from .stanford_cars import StanfordCars
from .stl10 import STL10
from .sun397 import SUN397
from .svhn import SVHN
from .ucf101 import UCF101
from .usps import USPS
from .vision import VisionDataset
from .voc import VOCDetection, VOCSegmentation
from .widerface import WIDERFace
# Excerpt of torchvision's `VisionDataset` (torchvision/datasets/vision.py).
# `...` marks code elided from the original, so this is not runnable as-is.
class VisionDataset(data.Dataset):
    """
    Base Class For making datasets which are compatible with torchvision.
    It is necessary to override the ``__getitem__`` and ``__len__`` method.

    Args:
        root (string, optional): Root directory of dataset. Only used for `__repr__`.
        transforms (callable, optional): A function/transforms that takes in
            an image and a label and returns the transformed versions of both.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    .. note::

        :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
    """

    def __init__(self, ...):
        ...
        # for backwards-compatibility
        self.transform = transform
        self.target_transform = target_transform

        # `has_separate_transform` is computed in the elided code, presumably
        # True when `transform` or `target_transform` was passed (verify
        # upstream). In that case both are wrapped into one StandardTransform,
        # so `self.transforms` is always the joint (image, target) transform.
        if has_separate_transform:
            transforms = StandardTransform(transform, target_transform)
        self.transforms = transforms
        ...

    # Abstract: subclasses must override both methods (see class docstring).
    def __getitem__(self, index: int) -> Any:
        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError

datasets.MNIST

The essentials of MNIST:

# Excerpt of torchvision's MNIST dataset (torchvision/datasets/mnist.py).
# `...` marks code elided from the original, so this is not runnable as-is.
class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (str or ``pathlib.Path``): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
            and  ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
        train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
            otherwise from ``t10k-images-idx3-ubyte``.
        transform (callable, optional): A function/transform that  takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    # Base URLs the raw files can be downloaded from (presumably tried in
    # order until one succeeds; verify upstream).
    mirrors = [
        "https://ossci-datasets.s3.amazonaws.com/mnist/",
        "http://yann.lecun.com/exdb/mnist/",
    ]

    # (archive filename, checksum) pairs. The 32-hex-digit values look like
    # MD5 digests used to verify downloads (TODO confirm upstream).
    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]

    # Not referenced by the code shown here (presumably legacy processed-file
    # names; verify upstream).
    training_file = "training.pt"
    test_file = "test.pt"
    # Human-readable class names; entry i names digit i.
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]

    def __init__(self, ...):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.train = train  # training set or test set
        ...
        # The whole split (all images and all labels) is read into memory here.
        self.data, self.targets = self._load_data()
        ...

    def _load_data(self):
        """Read the images and labels of the selected split from ``raw_folder``."""
        # "train-*" files for the training split, "t10k-*" files otherwise.
        image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
        data = read_image_file(os.path.join(self.raw_folder, image_file))

        label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
        targets = read_label_file(os.path.join(self.raw_folder, label_file))

        return data, targets

    def __getitem__(self, index: int) -> tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        # int(...) turns the stored label into a plain Python int.
        img, target = self.data[index], int(self.targets[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = _Image_fromarray(img.numpy(), mode="L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

Observations:

  • Given the nature of the dataset, the splits to be loaded are hardcoded in _load_data(self)
  • Labels are hardcoded, as are names of training and test files
  • All images have the same shape, so the DataLoader's default collate function can stack them into batches
  • Since we are dealing with a VisionDataset, we have to think about transforms (preprocessing vs augmentation, batching, etc.)
  • The entire dataset is loaded at once at initialization in _load_data(self)

Recommendations

Having been through the Big Data course, it should be clear why it is advantageous to

  • define various splits dynamically
  • sample based on class distributions
  • not load the entire dataset at once
  • add various metadata to each sample as pseudo-supervision
  • visualize the dataset based on various filters and sortings

All of this can be unlocked by using a reference SQL (here SQLite) database as the source of samples for the dataset.


In essence, something like:

from typing import Callable
from pathlib import Path

import pandas as pd
from PIL import Image

# Function that loads an sqlite3 database table as a pandas dataframe
from thesis.utils.db import load_db

# Base class to handle image and target transforms
from torchvision.datasets.vision import VisionDataset


class MNIST(VisionDataset):
    """MNIST ``VisionDataset`` backed by a precomputed SQLite index.

    Samples are the rows of the ``dataset`` table in ``data_dir / f"{name}.db"``
    whose ``split`` column equals ``split``. Each row provides at least an
    ``image_path`` (relative to ``data_dir``) and a ``label``. Only the index
    is loaded up front; images are read from disk in ``__getitem__``.

    Args:
        data_dir: Directory holding the database and the image files
            (a ``str`` is accepted too).
        name: Stem of the database file (``f"{name}.db"``).
        split: Value of the ``split`` column to keep, e.g. ``"train"``.
        seed: Reserved for choosing a cross-validation fold; currently unused.
        transform: Optional callable applied to the loaded PIL image.
        target_transform: Optional callable applied to the integer label.

    Raises:
        ValueError: If no row of the ``dataset`` table belongs to ``split``.
    """

    data_dir: Path
    db_path: str
    data: pd.DataFrame
    labels: pd.DataFrame

    def __init__(
        self,
        data_dir: Path,
        name: str,
        split: str,
        seed: int,
        transform: Callable | None = None,
        target_transform: Callable | None = None,
    ):

        # Initialize VisionDataset (stores `transform` / `target_transform`)
        super().__init__(data_dir, transform=transform, target_transform=target_transform)

        # Normalize to Path: the `/` joins below and in __getitem__ need one
        self.data_dir = Path(data_dir)

        # Choose fold in case of n-fold cross-validation of the dataset if needed
        # self.db_path = f"{name}-{seed}.db"

        del seed  # Not using it for now
        self.db_path = f"{name}.db"

        # Load dataframes from database tables
        #   df = load_db(db_path, db_table)
        self.data = load_db(self.data_dir / self.db_path, "dataset")
        self.labels = load_db(self.data_dir / self.db_path, "labels")

        # Apply specific query: keep only the requested split
        data = self.data[self.data["split"] == split]

        # Narrow the type for static checkers
        assert isinstance(data, pd.DataFrame)

        # A misspelled/unknown split would otherwise give an empty dataset
        # (and a DataLoader with zero batches) without any error.
        if data.empty:
            available = self.data["split"].unique().tolist()
            raise ValueError(
                f"No samples with split={split!r} in {self.data_dir / self.db_path}; "
                f"available splits: {available}"
            )

        # Set filtered data
        self.data = data

    def __getitem__(self, index: int) -> tuple[Image.Image, int]:
        """Return ``(image, target)`` for the ``index``-th row of the split."""

        # Retrieve row and all associated metadata
        row = self.data.iloc[index]

        # Read single sample. `Image.open` is lazy and keeps the file open until
        # the pixels are loaded; load them inside `with` so the handle is closed
        # even when no transform touches the image (e.g. transform=None).
        with Image.open(self.data_dir / row["image_path"]) as img:
            img.load()

        # Plain Python int, as annotated and as torchvision's MNIST returns,
        # rather than whatever scalar type the dataframe row holds.
        target = int(row["label"])

        # `VisionDataset` specific steps
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)

Usage:

from thesis import Paths # Defined in src/thesis/__init__.py
from thesis.datasets import MNIST # Defined above, stored under src/thesis/datasets.py

if __name__ == "__main__":
    data_dir = Paths.data / "MNIST"

    # Training split of the indexed MNIST dataset (seed is currently unused)
    dataset = MNIST(
        data_dir=data_dir,
        name="mnist-dataset",
        split="train",
        seed=10,
    )

    print(f"{len(dataset)} samples\n")

    # Show at most the first five samples; `min` avoids an IndexError on
    # splits with fewer than five rows.
    for i in range(min(5, len(dataset))):
        img, label = dataset[i]
        print(f" image: {img}")
        print(f" label: {label}")

Observations:

  • Loading splits no longer depends on the file structure, so the approach generalizes across datasets
  • Retains VisionDataset behaviour
  • Only a precomputed dataset index (mnist-dataset.db) is loaded up front, and images are read lazily in __getitem__. This avoids loading the whole dataset at once and helps the DataLoader use multiple workers efficiently.
  • Metadata columns can easily be added to or removed from each row
  • The dataframe can simply be cloned/copied to add metrics for each epoch, etc.
  • Requires an index_dataset step to create the db. This allows the dataset to be loaded 'instantly' without reading/listing all the files, and expensive metadata computations can be done once and cached
  • Allows various tools (e.g. datasette) to be used to inspect the dataset interactively

The flexibility of the second-to-last point and the benefits of the last point cannot be overstated!


Example: MNIST dataset

The following is the code for `scripts/index_dataset_mnist.py`. Working in the MNIST subdirectory of the data directory (`Paths.data / "MNIST"`, i.e. ./data/MNIST when the data directory is ./data), it:

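  • downloads MNIST (TODO: not implemented yet; the raw ubyte files are currently expected in the raw/ subdirectory)
  • saves the images as PNG files (images/) and the labels as CSV files (labels/)
  • and creates mnist-dataset.db with the files, dataset and labels tables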

Once the data is indexed, it can be browsed with datasette, as on wiki.thesis-guidelines.nl:

# Serve mnist-dataset.db with datasette on port 5555, listening on all
# interfaces. `--static images:images` exposes the local images/ directory
# under /images/, which is where the `image_html` column's "/images/..." URLs
# point.
cd data/MNIST && \
    datasette serve \
        --host 0.0.0.0 \
        --port 5555 \
        --static images:images \
        mnist-dataset.db
Example of visualizing MNIST using datasette.

Code for `scripts/index_dataset_mnist.py`:
from pathlib import Path
from thesis import Paths
from thesis.utils.db import load_db, save_db

from PIL import Image
import pandas as pd

from torchvision.datasets.mnist import read_image_file, read_label_file


def _load_data(data_dir: Path, train: bool):
    """Read one MNIST split from ``data_dir / "raw"`` with torchvision's readers.

    Args:
        data_dir: MNIST directory whose ``raw`` subdirectory holds the ubyte files.
        train: ``True`` for the ``train-*`` files, ``False`` for ``t10k-*``.

    Returns:
        ``(images, labels)`` as returned by ``read_image_file`` / ``read_label_file``.
    """
    prefix = "train" if train else "t10k"
    raw_dir = data_dir / "raw"

    images = read_image_file(str(raw_dir / f"{prefix}-images-idx3-ubyte"))
    labels = read_label_file(str(raw_dir / f"{prefix}-labels-idx1-ubyte"))
    return images, labels


def index_images(data_dir: Path, force: bool = False):
    """Convert the raw ubyte files to PNG images and CSV label files.

    For each split (``train``, ``test``), writes ``images/{split}_{i:06d}.png``
    for every sample ``i``. It also writes ``labels/{split}.csv`` with a single
    ``label`` column in the same order.

    Args:
        data_dir: MNIST directory whose ``raw/`` subdirectory holds the ubyte files.
        force: Regenerate the outputs even if they already exist.
    """
    images_dir = data_dir / "images"
    labels_dir = data_dir / "labels"
    splits = ("train", "test")

    # Skip only when all outputs are present. Checking `images/` alone would
    # treat a run interrupted before the label CSVs were written as complete.
    outputs_exist = images_dir.exists() and all(
        (labels_dir / f"{split}.csv").exists() for split in splits
    )
    if outputs_exist and not force:
        return

    data_train, targets_train = _load_data(data_dir, train=True)
    data_test, targets_test = _load_data(data_dir, train=False)

    print(f"{data_train.shape=}, {targets_train.shape=}")
    print(f"{data_test.shape=}, {targets_test.shape=}")
    # data_train.shape=torch.Size([60000, 28, 28]), targets_train.shape=torch.Size([60000])
    # data_test.shape=torch.Size([10000, 28, 28]), targets_test.shape=torch.Size([10000])

    images_dir.mkdir(exist_ok=True, parents=True)
    labels_dir.mkdir(exist_ok=True, parents=True)

    for split, data, targets in zip(splits, (data_train, data_test), (targets_train, targets_test)):
        # Save images as png: one 28x28 grayscale ("L") image per sample
        for i, img in enumerate(data.numpy()):
            Image.fromarray(img, mode="L").save(images_dir / f"{split}_{i:06d}.png")

        # Save ground truth (labels) as csv; row i belongs to image i
        labels_df = pd.DataFrame(targets.numpy(), columns=["label"])
        labels_df.to_csv(labels_dir / f"{split}.csv", index=False)


def index_files(data_dir: Path, db_path: Path, table: str, force: bool = False):
    """Write an inventory of the regular files under ``data_dir`` to ``table``.

    Each row holds a file's name, stem, extension, parent directory name and
    path relative to ``data_dir``. This is useful, e.g., to check which file
    extensions are present. Nothing is done when ``table`` can already be
    loaded from ``db_path``, unless ``force`` is set.
    """

    if db_path.exists() and not force:
        try:
            load_db(db_path, table, verbose=False)
            return
        except pd.errors.DatabaseError:
            ...  # Table not loadable (presumably missing): build it below

    # `glob("**/*")` also yields directories; keep regular files only. Sort
    # them so the table is reproducible across runs and filesystems.
    files = sorted(f for f in data_dir.glob("**/*") if f.is_file())
    rows = [
        {
            "filename": f.name,
            "filestem": f.stem,
            "extension": f.suffix,
            "parent": f.parent.name,
            "path": str(f.relative_to(data_dir)),
        }
        for f in files
    ]
    df = pd.DataFrame(rows)
    save_db(df, db_path, table, verbose=True)


def index_dataset(data_dir: Path, db_path: Path, table: str, force: bool = False):
    """Build the table the ``Dataset`` reads: one row per image, with metadata.

    The table has these columns:

    - ``image_path``: path relative to ``data_dir``, always with ``/``.
    - ``label``
    - ``split``
    - ``image_html``: a JSON string that datasette renders as the image.

    It expects the ``labels/{split}.csv`` files and the image naming produced
    by ``index_images``. Nothing is done when ``table`` can already be loaded
    from ``db_path``, unless ``force`` is set.
    """

    if db_path.exists() and not force:
        try:
            load_db(db_path, table, verbose=False)
            return
        except pd.errors.DatabaseError:
            ...  # Table not loadable (presumably missing): build it below

    # Load supervision
    labels_train = pd.read_csv(data_dir / "labels" / "train.csv")
    labels_test = pd.read_csv(data_dir / "labels" / "test.csv")

    # Index splits
    rows = []
    for split, labels in zip(["train", "test"], [labels_train, labels_test]):
        # Iterate the column once instead of building a Series per row with
        # `labels.iloc[i]`; row i of the CSV belongs to image i of the split.
        for i, label in enumerate(labels["label"].tolist()):
            # Joined with "/" rather than str(Path): the path is also used as
            # a URL below, which must not contain "\" separators on Windows.
            image_path = f"images/{split}_{i:06d}.png"
            rows.append(
                {
                    "image_path": image_path,
                    "label": label,
                    "split": split,
                    "image_html": f'{{"img_src": "/{image_path}", "width": 200}}',  # Integration with datasette
                }
            )

    # Save `dataset` table to database
    df = pd.DataFrame(rows)
    save_db(df, db_path, table, verbose=True)


def index_labels(db_path: Path, table: str, force: bool = False):
    """Write the MNIST class list (``label_idx`` -> ``label_str``) to ``table``.

    ``label_str`` has the form ``"<digit> - <word>"``, e.g. ``"3 - three"``.
    Nothing is written if ``table`` already loads from ``db_path`` and
    ``force`` is not set.
    """
    if db_path.exists() and not force:
        try:
            load_db(db_path, table, verbose=False)
        except pd.errors.DatabaseError:
            pass  # not loadable yet: fall through and (re)create it
        else:
            return

    digit_words = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
    classes = [
        {"label_idx": idx, "label_str": f"{idx} - {word}"}
        for idx, word in enumerate(digit_words)
    ]

    save_db(pd.DataFrame(classes), db_path, table, verbose=True)


if __name__ == "__main__":
    # All outputs live under <Paths.data>/MNIST. Every step is skipped when
    # its output already exists (force defaults to False).
    mnist_dir = Paths.data / "MNIST"
    mnist_db = mnist_dir / "mnist-dataset.db"

    index_images(mnist_dir)                               # raw ubyte -> images/*.png, labels/*.csv
    index_files(mnist_dir, mnist_db, table="files")       # inventory of files on disk
    index_dataset(mnist_dir, mnist_db, table="dataset")   # rows consumed by the Dataset
    index_labels(mnist_db, table="labels")                # class index -> name

Visualizing the dataset

It is important to note that:

>> Visualize = visualize(sample, ground truth, predictions)

We need a function that can accept a sample (loaded from a row in the dataset), the ground truth and the model predictions and return a figure/image.

If we allow the predictions to be None, we can also use the same function to visualize the dataset samples.


Examples of various modalities

Example 1: Features

For a dataset composed of features and label:

from PIL.Image import Image

Example 2: Images

For a dataset composed of images with classification, segmentation and detection as supervision:

Example 3: Videos

For a dataset composed of videos with classification, segmentation, detection and tracking as supervision: