Dataset
Building towards: MNIST dataset
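PyTorch's map-style datasets implement the `__getitem__` and `__len__` methods. `__getitem__` should raise `IndexError` for out-of-range indices: `torch.utils.data.Dataset` defines no `__iter__`, so plain Python iteration (e.g. `list(dataset)`) keeps calling `__getitem__(0)`, `__getitem__(1)`, … until an `IndexError` stops it.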
Minimal 'non-null' dataset
Below is an example of perhaps the simplest 'non-null' dataset, which just returns the index of the sample requested.
Note that when indexing the dataset directly we do not get `torch.Tensor`s yet, but plain Python integers.
The DataLoader converts samples to batches, sometimes automatically, sometimes requiring a custom collate function.
import torch
from torch.utils.data import Dataset, DataLoader
# Basic dataset that just returns the idx called
class Ids(Dataset[int]):
    """Minimal map-style dataset: the sample at position ``idx`` is ``idx`` itself.

    Args:
        total: Number of samples; valid indices are ``0 .. total - 1``
            (negative indices count from the end, like a Python sequence).
    """

    def __init__(self, total: int = 10):
        self.total = total

    def __len__(self) -> int:
        return self.total

    def __getitem__(self, idx: int) -> int:
        # Support negative indices like a regular sequence.
        if idx < 0:
            idx += self.total
        # Raising IndexError past the end is required for plain iteration:
        # Dataset defines no __iter__, so `list(dataset)` falls back to calling
        # __getitem__(0), __getitem__(1), ... and only stops on IndexError.
        if not 0 <= idx < self.total:
            raise IndexError(f"index out of range for dataset of size {self.total}")
        return idx
dataset = Ids()

print("\n\n--- Samples ---\n")
# Index explicitly over range(len(dataset)), the same indices the DataLoader's
# default sampler visits. A bare `list(dataset)` relies on the legacy sequence
# protocol and only terminates if __getitem__ raises IndexError past the end.
samples = [dataset[i] for i in range(len(dataset))]
print(samples)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

print("\n\n--- Batches ---\n")
# The default collate function turns each batch of ints into a tensor; the last
# batch is shorter because incomplete batches are kept (drop_last=False by default).
batches = list(DataLoader(dataset, shuffle=False, batch_size=3))
print(batches)
# [tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9])]
MNIST
What does the dataset from the MNIST example look like?
The relevant lines:
# Preprocessing shared by both splits: PIL image -> tensor, then normalize.
# 0.1307 / 0.3081 are presumably the MNIST training-set pixel mean / std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training split; download=True fetches the data into ../data if missing.
dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
# Test split; not downloaded separately (presumably already fetched by the call above).
dataset2 = datasets.MNIST("../data", train=False, download=False, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
VisionDataset
Since datasets.MNIST inherits from VisionDataset (class MNIST(VisionDataset)), let us check VisionDataset first:
It is a base class that handles image and target transformations. Other datasets built on it include FashionMNIST, etc.
The full list of `torchvision.datasets` can be found in torchvision's `datasets/__init__.py` (here `.venv/lib/python3.11/site-packages/torchvision/datasets/__init__.py`):
Contents of `.venv/lib/python3.11/site-packages/torchvision/datasets/__init__.py`
from ._optical_flow import FlyingChairs, FlyingThings3D, HD1K, KittiFlow, Sintel
from ._stereo_matching import (
CarlaStereo,
CREStereo,
ETH3DStereo,
FallingThingsStereo,
InStereo2k,
Kitti2012Stereo,
Kitti2015Stereo,
Middlebury2014Stereo,
SceneFlowStereo,
SintelStereo,
)
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
from .cityscapes import Cityscapes
from .clevr import CLEVRClassification
from .coco import CocoCaptions, CocoDetection
from .country211 import Country211
from .dtd import DTD
from .eurosat import EuroSAT
from .fakedata import FakeData
from .fer2013 import FER2013
from .fgvc_aircraft import FGVCAircraft
from .flickr import Flickr30k, Flickr8k
from .flowers102 import Flowers102
from .folder import DatasetFolder, ImageFolder
from .food101 import Food101
from .gtsrb import GTSRB
from .hmdb51 import HMDB51
from .imagenet import ImageNet
from .imagenette import Imagenette
from .inaturalist import INaturalist
from .kinetics import Kinetics
from .kitti import Kitti
from .lfw import LFWPairs, LFWPeople
from .lsun import LSUN, LSUNClass
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
from .moving_mnist import MovingMNIST
from .omniglot import Omniglot
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
from .phototour import PhotoTour
from .places365 import Places365
from .rendered_sst2 import RenderedSST2
from .sbd import SBDataset
from .sbu import SBU
from .semeion import SEMEION
from .stanford_cars import StanfordCars
from .stl10 import STL10
from .sun397 import SUN397
from .svhn import SVHN
from .ucf101 import UCF101
from .usps import USPS
from .vision import VisionDataset
from .voc import VOCDetection, VOCSegmentation
from .widerface import WIDERFace
# Excerpt from torchvision's VisionDataset; "..." marks code elided from the
# excerpt, so this snippet is not runnable as-is.
class VisionDataset(data.Dataset):
"""
Base Class For making datasets which are compatible with torchvision.
It is necessary to override the ``__getitem__`` and ``__len__`` method.
Args:
root (string, optional): Root directory of dataset. Only used for `__repr__`.
transforms (callable, optional): A function/transforms that takes in
an image and a label and returns the transformed versions of both.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
.. note::
:attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
"""
def __init__(self, ...):
...
# for backwards-compatibility
# Subclasses such as MNIST read self.transform / self.target_transform
# directly in their __getitem__.
self.transform = transform
self.target_transform = target_transform
# has_separate_transform is computed in the elided code; presumably True when
# transform or target_transform was passed (TODO confirm).
if has_separate_transform:
# Wrap the separate callables into one object, presumably applied to (image, target).
transforms = StandardTransform(transform, target_transform)
self.transforms = transforms
...
# Abstract: subclasses must override both of these methods.
def __getitem__(self, index: int) -> Any:
raise NotImplementedError
def __len__(self) -> int:
raise NotImplementedError
datasets.MNIST
The essentials of MNIST:
# Excerpt from torchvision's MNIST dataset; "..." marks code elided from the
# excerpt, so this snippet is not runnable as-is.
class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
Args:
root (str or ``pathlib.Path``): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte``
and ``MNIST/raw/t10k-images-idx3-ubyte`` exist.
train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``,
otherwise from ``t10k-images-idx3-ubyte``.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
# Base URLs the files in `resources` are presumably downloaded from, tried in
# order (the download code is not part of this excerpt).
mirrors = [
"https://ossci-datasets.s3.amazonaws.com/mnist/",
"http://yann.lecun.com/exdb/mnist/",
]
# (archive name, checksum) pairs; the 32-hex-digit strings look like MD5
# digests used to verify the downloads (TODO confirm).
resources = [
("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
]
# Not referenced by any of the code shown in this excerpt.
training_file = "training.pt"
test_file = "test.pt"
# Human-readable class names; entry i describes label i.
classes = [
"0 - zero",
"1 - one",
"2 - two",
"3 - three",
"4 - four",
"5 - five",
"6 - six",
"7 - seven",
"8 - eight",
"9 - nine",
]
def __init__(self, ...):
super().__init__(root, transform=transform, target_transform=target_transform)
self.train = train # training set or test set
...
# Eager loading: the whole split (all images and all labels) is read into memory here.
self.data, self.targets = self._load_data()
...
# Read the raw, uncompressed ubyte files (the `resources` names without ".gz")
# of the selected split from self.raw_folder.
def _load_data(self):
image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
data = read_image_file(os.path.join(self.raw_folder, image_file))
label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
targets = read_label_file(os.path.join(self.raw_folder, label_file))
return data, targets
def __getitem__(self, index: int) -> tuple[Any, Any]:
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
# int(...) turns the stored label into a plain Python int.
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
# (PIL mode "L" = 8-bit grayscale)
img = _Image_fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self) -> int:
return len(self.data)
Observations:
- Given the nature of the dataset, the splits to be loaded are hardcoded in `_load_data(self)`.
- Labels are hardcoded, as are the names of the training and test files.
- All images have the same shape, which lets the `DataLoader`'s default collation stack them into batches.
- Since we are dealing with a `VisionDataset`, we have to think about transforms (preprocessing vs. augmentation, batching, etc.).
- The entire dataset is loaded at once at initialization in `_load_data(self)`.
Recommendations
Having been through the Big Data course, it should be clear how it would be advantageous to
- define various splits dynamically
- sample based on class distributions
- not load the entire dataset at once
- add various metadata to each sample as pseudo-supervision
- visualize the dataset based on various filters and sortings
All of this can be unlocked by using a reference SQL database as the source of samples for the dataset.
Structure of recommended dataset
In essence, something like:
from typing import Callable
from pathlib import Path
import pandas as pd
from PIL import Image
# Function that loads an sqlite3 database table as a pandas dataframe
from thesis.utils.db import load_db
# Base class to handle image and target transforms
from torchvision.datasets.vision import VisionDataset
class MNIST(VisionDataset):
    """MNIST dataset whose samples are listed in an SQLite index.

    Instead of parsing the raw ubyte files, the ``dataset`` table of
    ``data_dir / f"{name}.db"`` is loaded (one row per sample, with at least
    ``image_path``, ``label`` and ``split`` columns) and filtered to ``split``.
    Image files are only read from disk in ``__getitem__``.

    Args:
        data_dir: Directory holding the database and the images that
            ``image_path`` is relative to (``str`` is accepted as well).
        name: Database file stem, e.g. ``"mnist-dataset"``.
        split: Value of the ``split`` column to keep, e.g. ``"train"``.
        seed: Reserved for choosing a cross-validation fold; currently unused.
        transform: Optional callable applied to the PIL image.
        target_transform: Optional callable applied to the integer label.
    """

    data_dir: Path
    db_path: str  # database file name, relative to data_dir
    data: pd.DataFrame  # rows of the `dataset` table for the selected split
    labels: pd.DataFrame  # the `labels` table

    def __init__(
        self,
        data_dir: Path | str,
        name: str,
        split: str,
        seed: int,
        transform: Callable | None = None,
        target_transform: Callable | None = None,
    ):
        # Initialize VisionDataset (stores transform / target_transform)
        super().__init__(data_dir, transform=transform, target_transform=target_transform)
        # Normalize to Path so the `/` joins below also work when a str is passed
        self.data_dir = Path(data_dir)

        # Choose fold in case of n-fold cross-validation of the dataset if needed
        # self.db_path = f"{name}-{seed}.db"
        del seed  # Not using it for now
        self.db_path = f"{name}.db"

        # Load dataframes from database tables
        db_file = self.data_dir / self.db_path
        self.data = load_db(db_file, "dataset")
        self.labels = load_db(db_file, "labels")

        # Keep only the rows of the requested split; the original index is kept,
        # rows are accessed positionally via .iloc in __getitem__
        data = self.data[self.data.split == split]
        assert isinstance(data, pd.DataFrame)  # narrows the type for type checkers
        self.data = data

    def __getitem__(self, index: int) -> tuple[Image.Image, int]:
        """Return ``(image, label)`` for the ``index``-th row of the split."""
        # Retrieve row and all associated metadata
        row = self.data.iloc[index]
        # Image.open is lazy and keeps the file open until the pixels are read;
        # copying inside `with` loads the pixels and closes the file right away,
        # so DataLoader workers don't accumulate open file handles.
        with Image.open(self.data_dir / row.image_path) as img_file:
            img = img_file.copy()
        # Plain int (not a numpy scalar), as annotated and as torchvision's MNIST returns
        target = int(row.label)

        # `VisionDataset` specific steps
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self) -> int:
        return len(self.data)
Usage:
from thesis import Paths # Defined in src/thesis/__init__.py
from thesis.datasets import MNIST # Defined above, stored under src/thesis/datasets.py
if __name__ == "__main__":
    # Build the training split from the SQLite index under <data>/MNIST.
    dataset = MNIST(
        data_dir=Paths.data / "MNIST",
        name="mnist-dataset",
        split="train",
        seed=10,
    )
    print(f"{len(dataset)} samples\n")

    # Show the first five (image, label) pairs.
    for sample_idx in range(5):
        img, label = dataset[sample_idx]
        print(f" image: {img}")
        print(f" label: {label}")
Observations:
- Loading of splits no longer depends on the file structure, so the approach generalizes across datasets.
- Retains `VisionDataset` behaviour.
- Only a precomputed dataset index (`mnist-dataset.db`) is loaded, which avoids loading the whole dataset at once and helps the `DataLoader` use multiple workers efficiently.
- Metadata can easily be added to or removed from each row.
- The dataframe can simply be cloned/copied to add metrics for each epoch, etc.
- Needs an `index_dataset` step to create the db. This allows the dataset to be loaded 'instantly' without having to read/list all the files, and expensive metadata computations can be done once and cached.
- Allows the use of various tools to inspect the dataset interactively.

The flexibility of the second-to-last point and the benefits of the last point cannot be overstated!
Example: MNIST dataset
The following is the code for scripts/index_dataset_mnist.py.
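Given a data directory (assumed to be `./data`; the script works inside its `MNIST/` subdirectory), it:
- downloads MNIST (TODO: not implemented yet; the extracted raw `*-ubyte` files must already be in `MNIST/raw/`)
- saves the data as PNG images plus one label CSV per split
- and creates `mnist-dataset.db`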
Once the data is indexed, it can be browsed with datasette (as on wiki.thesis-guidelines.nl):
# Browse the index with datasette (run from the directory that contains data/).
# --host 0.0.0.0 makes the server reachable from other machines, not just localhost.
# --static images:images serves the local images/ directory under /images/,
# which is where the "/images/..." URLs in the image_html column point.
cd data/MNIST && \
datasette serve \
--host 0.0.0.0 \
--port 5555 \
--static images:images \
mnist-dataset.db
Example of visualizing MNIST using datasette.

Code for `scripts/index_dataset_mnist.py`:
from pathlib import Path
from thesis import Paths
from thesis.utils.db import load_db, save_db
from PIL import Image
import pandas as pd
from torchvision.datasets.mnist import read_image_file, read_label_file
def _load_data(data_dir: Path, train: bool):
    """Read one MNIST split from ``data_dir / "raw"`` with torchvision's ubyte readers.

    Args:
        data_dir: Directory whose ``raw/`` subfolder holds the uncompressed ubyte files.
        train: True for the training split (``train-*``), False for the test split (``t10k-*``).

    Returns:
        ``(data, targets)`` as returned by ``read_image_file`` / ``read_label_file``.
    """
    prefix = "train" if train else "t10k"
    raw_dir = data_dir / "raw"
    data = read_image_file(str(raw_dir / f"{prefix}-images-idx3-ubyte"))
    targets = read_label_file(str(raw_dir / f"{prefix}-labels-idx1-ubyte"))
    return data, targets
def index_images(data_dir: Path, force: bool = False):
    """Convert the raw ubyte MNIST files into PNG images and CSV label files.

    For each split in ("train", "test") this writes:
      - ``data_dir/images/{split}_{i:06d}.png`` for every sample ``i``;
      - ``data_dir/labels/{split}.csv`` with a single ``label`` column, where
        row ``i`` is the label of image ``i``.

    Does nothing if a previous run completed, unless ``force`` is True.
    """
    splits = ("train", "test")
    images_dir = data_dir / "images"
    labels_dir = data_dir / "labels"
    label_files = [labels_dir / f"{split}.csv" for split in splits]

    # A split's CSV is written only after all of its images, so both CSVs
    # existing means a previous run finished. Checking images/ alone would skip
    # forever after an interrupted run and leave the export incomplete.
    if images_dir.exists() and all(f.exists() for f in label_files) and not force:
        return

    data_train, targets_train = _load_data(data_dir, train=True)
    data_test, targets_test = _load_data(data_dir, train=False)
    print(f"{data_train.shape=}, {targets_train.shape=}")
    print(f"{data_test.shape=}, {targets_test.shape=}")
    # data_train.shape=torch.Size([60000, 28, 28]), targets_train.shape=torch.Size([60000])
    # data_test.shape=torch.Size([10000, 28, 28]), targets_test.shape=torch.Size([10000])

    images_dir.mkdir(exist_ok=True, parents=True)
    labels_dir.mkdir(exist_ok=True, parents=True)

    for split, data, targets, label_file in zip(
        splits, (data_train, data_test), (targets_train, targets_test), label_files
    ):
        # Save images as png (iterating the tensor yields one 28x28 image per step)
        for i, img in enumerate(data):
            Image.fromarray(img.numpy(), mode="L").save(images_dir / f"{split}_{i:06d}.png")
        # Save ground truth (labels) as csv last, so its presence marks the split as complete
        pd.DataFrame(targets.numpy(), columns=["label"]).to_csv(label_file, index=False)
def index_files(data_dir: Path, db_path: Path, table: str, force: bool = False):
    """Index every regular file under ``data_dir`` into ``table`` of ``db_path``.

    One row per file with its name, stem, extension, parent directory name and
    path relative to ``data_dir``; handy for checking which file types exist.
    Skipped when the table can already be loaded, unless ``force`` is True.
    """
    if db_path.exists() and not force:
        try:
            load_db(db_path, table, verbose=False)
        except pd.errors.DatabaseError:
            pass  # presumably the table does not exist yet -> (re)build it below
        else:
            return

    # Directories are skipped (they would show up as rows with an empty
    # extension); sorting makes the row order independent of the filesystem.
    files = sorted(f for f in data_dir.glob("**/*") if f.is_file())
    rows = [
        {
            "filename": f.name,
            "filestem": f.stem,
            "extension": f.suffix,
            "parent": f.parent.name,
            "path": str(f.relative_to(data_dir)),
        }
        for f in files
    ]
    # Explicit columns keep the table schema even when no files are found
    df = pd.DataFrame(rows, columns=["filename", "filestem", "extension", "parent", "path"])
    save_db(df, db_path, table, verbose=True)
def index_dataset(data_dir: Path, db_path: Path, table: str, force: bool = False):
    """Build the per-sample table that the MNIST ``Dataset`` loads.

    One row per image of both splits with columns ``image_path`` (relative to
    ``data_dir``, forward slashes), ``label`` (int), ``split`` and ``image_html``
    (a JSON snippet used to display the image in datasette). Reads the label
    CSVs written by ``index_images``. Skipped when the table can already be
    loaded, unless ``force`` is True.
    """
    if db_path.exists() and not force:
        try:
            load_db(db_path, table, verbose=False)
        except pd.errors.DatabaseError:
            pass  # presumably the table does not exist yet -> (re)build it below
        else:
            return

    # Load supervision
    labels_train = pd.read_csv(data_dir / "labels" / "train.csv")
    labels_test = pd.read_csv(data_dir / "labels" / "test.csv")

    # Index splits
    rows = []
    for split, labels in zip(["train", "test"], [labels_train, labels_test]):
        for i, label in enumerate(labels["label"]):
            # Forward slashes on every OS: the same string is joined onto data_dir
            # by the Dataset and used as a URL path in image_html.
            image_path = Path("images", f"{split}_{i:06d}.png").as_posix()
            rows.append(
                {
                    "image_path": image_path,
                    "label": int(label),
                    "split": split,
                    # Integration with datasette; presumably rendered as an <img>
                    # by the datasette-json-html plugin (TODO confirm)
                    "image_html": f'{{"img_src": "/{image_path}", "width": 200}}',
                }
            )

    # Save `dataset` table to database
    df = pd.DataFrame(rows, columns=["image_path", "label", "split", "image_html"])
    save_db(df, db_path, table, verbose=True)
def index_labels(db_path: Path, table: str, force: bool = False):
    """Write the class table (``label_idx``, ``label_str``) for the ten MNIST digits.

    ``label_str`` has the form ``"<idx> - <digit name>"``, e.g. ``"3 - three"``.
    Skipped when the table can already be loaded, unless ``force`` is True.
    """
    if db_path.exists() and not force:
        try:
            load_db(db_path, table, verbose=False)
        except pd.errors.DatabaseError:
            pass  # table not loadable yet -> build it below
        else:
            return

    digit_names = ("zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine")
    class_rows = [
        {"label_idx": idx, "label_str": f"{idx} - {word}"}
        for idx, word in enumerate(digit_names)
    ]
    save_db(pd.DataFrame(class_rows), db_path, table, verbose=True)
if __name__ == "__main__":
    data_dir = Paths.data / "MNIST"
    db_path = data_dir.joinpath("mnist-dataset.db")

    # Order matters: index_dataset reads the label CSVs written by index_images.
    index_images(data_dir)
    index_files(data_dir, db_path, table="files")
    index_dataset(data_dir, db_path, table="dataset")
    index_labels(db_path, table="labels")
Visualizing the dataset
It is important to note that a visualization depends on three inputs:

    visualize(sample, ground_truth, predictions) -> figure

We need a function that accepts a sample (loaded from a row in the dataset), its ground truth and the model predictions, and returns a figure/image.
If we allow the predictions to be None, the same function can also be used to visualize the dataset samples themselves.
Examples of various modalities
Example 1: Features
For a dataset composed of features and label:
Example 2: Images
For a dataset composed of images with classification, segmentation and detection as supervision:
Example 3: Videos
For a dataset composed of videos with classification, segmentation, detection and tracking as supervision: