From 61032e88da14c502039e15b4d53510823d2f94fc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 9 Nov 2021 13:58:18 +0000 Subject: [PATCH 1/3] Add HD1K dataset for optical flow --- test/test_datasets.py | 66 ++++++++++++++++++++++++++ torchvision/datasets/__init__.py | 3 +- torchvision/datasets/_optical_flow.py | 68 +++++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 1 deletion(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index fc9363e89bb..36c933c7610 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2125,5 +2125,71 @@ def test_bad_input(self): pass +class HD1KFlowTestCase(datasets_utils.ImageDatasetTestCase): + DATASET_CLASS = datasets.HD1K + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) + FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) + + def inject_fake_data(self, tmpdir, config): + root = pathlib.Path(tmpdir) / "hd1k" + + num_sequences = 4 if config["split"] == "train" else 3 + num_examples_per_train_sequence = 3 + + for seq_idx in range(num_sequences): + # Training data + datasets_utils.create_image_folder( + root / "hd1k_input", + name="image_2", + file_name_fn=lambda image_idx: f"{seq_idx:06d}_{image_idx}.png", + num_examples=num_examples_per_train_sequence, + ) + datasets_utils.create_image_folder( + root / "hd1k_flow_gt", + name="flow_occ", + file_name_fn=lambda image_idx: f"{seq_idx:06d}_{image_idx}.png", + num_examples=num_examples_per_train_sequence, + ) + + # Test data + datasets_utils.create_image_folder( + root / "hd1k_challenge", + name="image_2", + file_name_fn=lambda _: f"{seq_idx:06d}_10.png", + num_examples=1, + ) + datasets_utils.create_image_folder( + root / "hd1k_challenge", + name="image_2", + file_name_fn=lambda _: f"{seq_idx:06d}_11.png", + num_examples=1, + ) + + num_examples_per_sequence = num_examples_per_train_sequence if config["split"] == "train" else 2 + return num_sequences * (num_examples_per_sequence - 1) + + def test_flow_and_valid(self): + # Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images + # Also assert flow and valid are of the expected shape + with self.create_dataset(split="train") as (dataset, _): + assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list) + for _, _, flow, valid in dataset: + two, h, w = flow.shape + assert two == 2 + assert valid.shape == (h, w) + + # Make sure flow and valid are always None for test split + with self.create_dataset(split="test") as (dataset, _): + assert dataset._image_list and not dataset._flow_list + for _, _, flow, valid in dataset: + assert flow is None + assert valid is None + + def test_bad_input(self): + with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): + with self.create_dataset(split="bad"): + pass + + if __name__ == "__main__": unittest.main() diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index e057c45364d..80859791004 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,4 +1,4 @@ -from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D +from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D, HD1K from .caltech import Caltech101, Caltech256 from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 @@ -76,4 +76,5 @@ "Sintel", "FlyingChairs", "FlyingThings3D", + "HD1K", ) diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py index 7c728a5af8f..a1deba7600f 100644 --- a/torchvision/datasets/_optical_flow.py +++ b/torchvision/datasets/_optical_flow.py @@ -19,6 +19,7 @@ "Sintel", "FlyingThings3D", "FlyingChairs", + "HD1K", ) @@ -362,6 +363,73 @@ def _read_flow(self, file_name): return _read_pfm(file_name) +class HD1K(FlowDataset): + """`HD1K `__ dataset for optical flow. + + The dataset is expected to have the following structure: :: + + root + hd1k + hd1k_challenge + image_2 + hd1k_flow_gt + flow_occ + hd1k_input + image_2 + + Args: + root (string): Root directory of the HD1K Dataset. + split (string, optional): The dataset split, either "train" (default) or "test" + transforms (callable, optional): A function/transform that takes in + ``img1, img2, flow, valid`` and returns a transformed version. + """ + + _has_builtin_flow_mask = True + + def __init__(self, root, split="train", transforms=None): + super().__init__(root=root, transforms=transforms) + + verify_str_arg(split, "split", valid_values=("train", "test")) + + root = Path(root) / "hd1k" + if split == "train": + # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop + for seq_idx in range(36): + flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png"))) + images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png"))) + for i in range(len(flows) - 1): + self._flow_list += [flows[i]] + self._image_list += [[images[i], images[i + 1]]] + else: + images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png"))) + images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png"))) + for image1, image2 in zip(images1, images2): + self._image_list += [[image1, image2]] + + if not self._image_list: + raise FileNotFoundError( + "Could not find the HD1K images. Please make sure the directory structure is correct." + ) + + def _read_flow(self, file_name): + return _read_16bits_png_with_flow_and_valid_mask(file_name) + + def __getitem__(self, index): + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow, + valid)`` where ``valid`` is a numpy boolean mask of shape (H, W) + indicating which flow values are valid. The flow is a numpy array of + shape (2, H, W) and the images are PIL images. If `split="test"`, a + 4-tuple with ``(img1, img2, None, None)`` is returned. + """ + return super().__getitem__(index) + + def _read_flo(file_name): """Read .flo file in Middlebury format""" # Code adapted from: From fa544a8e65f1ab0147094dc457859efcdce0825d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 9 Nov 2021 15:41:05 +0000 Subject: [PATCH 2/3] Add docs --- docs/source/datasets.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 093e40ca904..7f09ff245ca 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -45,6 +45,7 @@ You can also create your own datasets using the provided :ref:`base classes Date: Tue, 9 Nov 2021 15:47:06 +0000 Subject: [PATCH 3/3] simplify tests --- test/test_datasets.py | 26 +------------------------- 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index cac0e829e42..761f11d77dc 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2126,10 +2126,8 @@ def test_bad_input(self): pass -class HD1KFlowTestCase(datasets_utils.ImageDatasetTestCase): +class HD1KTestCase(KittiFlowTestCase): DATASET_CLASS = datasets.HD1K - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) - FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None))) def inject_fake_data(self, tmpdir, config): root = pathlib.Path(tmpdir) / "hd1k" @@ -2169,28 +2167,6 @@ def inject_fake_data(self, tmpdir, config): num_examples_per_sequence = num_examples_per_train_sequence if config["split"] == "train" else 2 return num_sequences * (num_examples_per_sequence - 1) - def test_flow_and_valid(self): - # Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images - # Also assert flow and valid are of the expected shape - with self.create_dataset(split="train") as (dataset, _): - assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list) - for _, _, flow, valid in dataset: - two, h, w = flow.shape - assert two == 2 - assert valid.shape == (h, w) - - # Make sure flow and valid are always None for test split - with self.create_dataset(split="test") as (dataset, _): - assert dataset._image_list and not dataset._flow_list - for _, _, flow, valid in dataset: - assert flow is None - assert valid is None - - def test_bad_input(self): - with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"): - with self.create_dataset(split="bad"): - pass - if __name__ == "__main__": unittest.main()