Skip to content

Add HD1K dataset for optical flow #4890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
Flickr30k
FlyingChairs
FlyingThings3D
HD1K
HMDB51
ImageNet
INaturalist
Expand Down
42 changes: 42 additions & 0 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2126,5 +2126,47 @@ def test_bad_input(self):
pass


class HD1KTestCase(KittiFlowTestCase):
DATASET_CLASS = datasets.HD1K

def inject_fake_data(self, tmpdir, config):
root = pathlib.Path(tmpdir) / "hd1k"

num_sequences = 4 if config["split"] == "train" else 3
num_examples_per_train_sequence = 3

for seq_idx in range(num_sequences):
# Training data
datasets_utils.create_image_folder(
root / "hd1k_input",
name="image_2",
file_name_fn=lambda image_idx: f"{seq_idx:06d}_{image_idx}.png",
num_examples=num_examples_per_train_sequence,
)
datasets_utils.create_image_folder(
root / "hd1k_flow_gt",
name="flow_occ",
file_name_fn=lambda image_idx: f"{seq_idx:06d}_{image_idx}.png",
num_examples=num_examples_per_train_sequence,
)

# Test data
datasets_utils.create_image_folder(
root / "hd1k_challenge",
name="image_2",
file_name_fn=lambda _: f"{seq_idx:06d}_10.png",
num_examples=1,
)
datasets_utils.create_image_folder(
root / "hd1k_challenge",
name="image_2",
file_name_fn=lambda _: f"{seq_idx:06d}_11.png",
num_examples=1,
)

num_examples_per_sequence = num_examples_per_train_sequence if config["split"] == "train" else 2
return num_sequences * (num_examples_per_sequence - 1)


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D, HD1K
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import CIFAR10, CIFAR100
Expand Down Expand Up @@ -76,4 +76,5 @@
"Sintel",
"FlyingChairs",
"FlyingThings3D",
"HD1K",
)
68 changes: 68 additions & 0 deletions torchvision/datasets/_optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"Sintel",
"FlyingThings3D",
"FlyingChairs",
"HD1K",
)


Expand Down Expand Up @@ -363,6 +364,73 @@ def _read_flow(self, file_name):
return _read_pfm(file_name)


class HD1K(FlowDataset):
"""`HD1K <http://hci-benchmark.iwr.uni-heidelberg.de/>`__ dataset for optical flow.

The dataset is expected to have the following structure: ::

root
hd1k
hd1k_challenge
image_2
hd1k_flow_gt
flow_occ
hd1k_input
image_2

Args:
root (string): Root directory of the HD1K Dataset.
split (string, optional): The dataset split, either "train" (default) or "test"
transforms (callable, optional): A function/transform that takes in
``img1, img2, flow, valid`` and returns a transformed version.
"""

_has_builtin_flow_mask = True

def __init__(self, root, split="train", transforms=None):
super().__init__(root=root, transforms=transforms)

verify_str_arg(split, "split", valid_values=("train", "test"))

root = Path(root) / "hd1k"
if split == "train":
# There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop
for seq_idx in range(36):
flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png")))
images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png")))
for i in range(len(flows) - 1):
self._flow_list += [flows[i]]
self._image_list += [[images[i], images[i + 1]]]
else:
images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png")))
images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png")))
for image1, image2 in zip(images1, images2):
self._image_list += [[image1, image2]]

if not self._image_list:
raise FileNotFoundError(
"Could not find the HD1K images. Please make sure the directory structure is correct."
)

def _read_flow(self, file_name):
return _read_16bits_png_with_flow_and_valid_mask(file_name)

def __getitem__(self, index):
"""Return example at given index.

Args:
index(int): The index of the example to retrieve

Returns:
tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow,
valid)`` where ``valid`` is a numpy boolean mask of shape (H, W)
indicating which flow values are valid. The flow is a numpy array of
shape (2, H, W) and the images are PIL images. If `split="test"`, a
4-tuple with ``(img1, img2, None, None)`` is returned.
"""
return super().__getitem__(index)


def _read_flo(file_name):
"""Read .flo file in Middlebury format"""
# Code adapted from:
Expand Down