diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index ae5cf71a95d..093e40ca904 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -44,6 +44,7 @@ You can also create your own datasets using the provided :ref:`base classes
diff --git a/test/datasets_utils.py b/test/datasets_utils.py
--- a/test/datasets_utils.py
+++ b/test/datasets_utils.py
@@ ... @@ ... str:
     return "".join(random.choice(digits) for _ in range(length))


+def make_fake_pfm_file(h, w, file_name):
+    """Creates a fake flow file in .pfm format."""
+    values = list(range(3 * h * w))
+    # Everything is packed in little endian: hence the -1.0 scale in the header and the "<" struct format
+    content = f"PF \n{w} {h} \n-1.0\n".encode() + struct.pack("<" + "f" * len(values), *values)
+    with open(file_name, "wb") as f:
+        f.write(content)
+
+
 def make_fake_flo_file(h, w, file_name):
     """Creates a fake flow file in .flo format."""
     values = list(range(2 * h * w))
diff --git a/test/test_datasets.py b/test/test_datasets.py
index e355cfc5b40..fa55a0f2d5b 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -2048,5 +2048,72 @@ def test_flow(self, config):
                 np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))


+class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.FlyingThings3D
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
+        split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both")
+    )
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
+
+    FLOW_H, FLOW_W = 3, 4
+
+    def inject_fake_data(self, tmpdir, config):
+        root = pathlib.Path(tmpdir) / "FlyingThings3D"
+
+        num_images_per_camera = 3 if config["split"] == "train" else 4
+        passes = ("frames_cleanpass", "frames_finalpass")
+        splits = ("TRAIN", "TEST")
+        letters = ("A", "B", "C")
+        subfolders = ("0000", "0001")
+        cameras = ("left", "right")
+        for pass_name, split, letter, subfolder, camera in itertools.product(
+            passes, splits, letters, subfolders, cameras
+        ):
+            current_folder = root / pass_name / split / letter / subfolder
+            datasets_utils.create_image_folder(
+                current_folder,
+                name=camera,
+                file_name_fn=lambda image_idx: f"00{image_idx}.png",
+                num_examples=num_images_per_camera,
+            )
+
+        directions = ("into_future", "into_past")
+        for split, letter, subfolder, direction, camera in itertools.product(
+            splits, letters, subfolders, directions, cameras
+        ):
+            current_folder = root / "optical_flow" / split / letter / subfolder / direction / camera
+            os.makedirs(str(current_folder), exist_ok=True)
+            for i in range(num_images_per_camera):
+                datasets_utils.make_fake_pfm_file(self.FLOW_H, self.FLOW_W, file_name=str(current_folder / f"{i}.pfm"))
+
+        num_cameras = 2 if config["camera"] == "both" else 1
+        num_passes = 2 if config["pass_name"] == "both" else 1
+        num_examples = (
+            (num_images_per_camera - 1) * num_cameras * len(subfolders) * len(letters) * len(directions) * num_passes
+        )
+        return num_examples
+
+    @datasets_utils.test_all_configs
+    def test_flow(self, config):
+        with self.create_dataset(config=config) as (dataset, _):
+            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
+            for _, _, flow in dataset:
+                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
+                # We don't check the values: the reshaping and flipping in _read_pfm make them hard to predict
+
+    def test_bad_input(self):
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
+            with self.create_dataset(split="bad"):
+                pass
+
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
+            with self.create_dataset(pass_name="bad"):
+                pass
+
+        with pytest.raises(ValueError, match="Unknown value 'bad' for argument camera"):
+            with self.create_dataset(camera="bad"):
+                pass
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index dfad4770a93..e057c45364d 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -1,4 +1,4 @@
-from ._optical_flow import KittiFlow, Sintel, FlyingChairs
+from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D
 from .caltech import Caltech101, Caltech256
 from .celeba import CelebA
 from .cifar import CIFAR10, CIFAR100
@@ -75,4 +75,5 @@
     "KittiFlow",
     "Sintel",
     "FlyingChairs",
+    "FlyingThings3D",
 )
diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py
index f26127039d1..6ff49395a0a 100644
--- a/torchvision/datasets/_optical_flow.py
+++ b/torchvision/datasets/_optical_flow.py
@@ -1,4 +1,6 @@
+import itertools
 import os
+import re
 from abc import ABC, abstractmethod
 from glob import glob
 from pathlib import Path
@@ -15,6 +17,7 @@
 __all__ = (
     "KittiFlow",
     "Sintel",
+    "FlyingThings3D",
     "FlyingChairs",
 )
@@ -271,6 +274,94 @@ def _read_flow(self, file_name):
         return _read_flo(file_name)


+class FlyingThings3D(FlowDataset):
+    """`FlyingThings3D <https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html>`_ dataset for optical flow.
+
+    The dataset is expected to have the following structure: ::
+
+        root
+            FlyingThings3D
+                frames_cleanpass
+                    TEST
+                    TRAIN
+                frames_finalpass
+                    TEST
+                    TRAIN
+                optical_flow
+                    TEST
+                    TRAIN
+
+    Args:
+        root (string): Root directory of the FlyingThings3D Dataset.
+        split (string, optional): The dataset split, either "train" (default) or "test"
+        pass_name (string, optional): The pass to use, either "clean" (default) or "final" or "both". See link above for
+            details on the different passes.
+        camera (string, optional): Which camera to return images from. Can be either "left" (default) or "right" or "both".
+        transforms (callable, optional): A function/transform that takes in
+            ``img1, img2, flow, valid`` and returns a transformed version.
+            ``valid`` is expected for consistency with other datasets which
+            return a built-in valid mask, such as :class:`~torchvision.datasets.KittiFlow`.
+    """
+
+    def __init__(self, root, split="train", pass_name="clean", camera="left", transforms=None):
+        super().__init__(root=root, transforms=transforms)
+
+        verify_str_arg(split, "split", valid_values=("train", "test"))
+        split = split.upper()
+
+        verify_str_arg(pass_name, "pass_name", valid_values=("clean", "final", "both"))
+        passes = {
+            "clean": ["frames_cleanpass"],
+            "final": ["frames_finalpass"],
+            "both": ["frames_cleanpass", "frames_finalpass"],
+        }[pass_name]
+
+        verify_str_arg(camera, "camera", valid_values=("left", "right", "both"))
+        cameras = ["left", "right"] if camera == "both" else [camera]
+
+        root = Path(root) / "FlyingThings3D"
+
+        directions = ("into_future", "into_past")
+        for pass_name, camera, direction in itertools.product(passes, cameras, directions):
+            image_dirs = sorted(glob(str(root / pass_name / split / "*/*")))
+            image_dirs = sorted([Path(image_dir) / camera for image_dir in image_dirs])
+
+            flow_dirs = sorted(glob(str(root / "optical_flow" / split / "*/*")))
+            flow_dirs = sorted([Path(flow_dir) / direction / camera for flow_dir in flow_dirs])
+
+            if not image_dirs or not flow_dirs:
+                raise FileNotFoundError(
+                    "Could not find the FlyingThings3D flow images. "
+                    "Please make sure the directory structure is correct."
+                )
+
+            for image_dir, flow_dir in zip(image_dirs, flow_dirs):
+                images = sorted(glob(str(image_dir / "*.png")))
+                flows = sorted(glob(str(flow_dir / "*.pfm")))
+                for i in range(len(flows) - 1):
+                    if direction == "into_future":
+                        self._image_list += [[images[i], images[i + 1]]]
+                        self._flow_list += [flows[i]]
+                    elif direction == "into_past":
+                        self._image_list += [[images[i + 1], images[i]]]
+                        self._flow_list += [flows[i + 1]]
+
+    def __getitem__(self, index):
+        """Return example at given index.
+
+        Args:
+            index (int): The index of the example to retrieve
+
+        Returns:
+            tuple: A 3-tuple with ``(img1, img2, flow)``.
+            The flow is a numpy array of shape (2, H, W) and the images are PIL images.
+        """
+        return super().__getitem__(index)
+
+    def _read_flow(self, file_name):
+        return _read_pfm(file_name)
+
+
 def _read_flo(file_name):
     """Read .flo file in Middlebury format"""
     # Code adapted from:
@@ -295,3 +386,31 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name):

     # For consistency with other datasets, we convert to numpy
     return flow.numpy(), valid.numpy()
+
+
+def _read_pfm(file_name):
+    """Read flow in .pfm format"""
+
+    with open(file_name, "rb") as f:
+        header = f.readline().rstrip()
+        if header != b"PF":
+            raise ValueError("Invalid PFM file")
+
+        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
+        if not dim_match:
+            raise ValueError("Malformed PFM header.")
+        w, h = (int(dim) for dim in dim_match.groups())
+
+        scale = float(f.readline().rstrip())
+        if scale < 0:  # little-endian
+            endian = "<"
+            scale = -scale
+        else:
+            endian = ">"  # big-endian
+
+        data = np.fromfile(f, dtype=endian + "f")
+
+    data = data.reshape(h, w, 3).transpose(2, 0, 1)
+    data = np.flip(data, axis=1)  # flip on h dimension
+    data = data[:2, :, :]
+    return data.astype(np.float32)
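
For reference, a minimal usage sketch of the dataset class added by this patch (the sketch itself is not part of the diff). The "./data" root and the assumption that a full FlyingThings3D download has already been extracted there, with the layout described in the class docstring, are illustrative only.

# Usage sketch: assumes torchvision built with this patch and an extracted
# FlyingThings3D dataset under "./data/FlyingThings3D" (both are assumptions).
from torchvision import datasets

dataset = datasets.FlyingThings3D(root="./data", split="train", pass_name="clean", camera="left")

img1, img2, flow = dataset[0]  # img1/img2 are PIL images, flow is a (2, H, W) numpy array
print(img1.size, img2.size, flow.shape)

Passing pass_name="both" or camera="both" simply concatenates the corresponding image and flow lists, as the itertools.product loop in the constructor shows.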