diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 093e40ca904..7f09ff245ca 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -45,6 +45,7 @@ You can also create your own datasets using the provided :ref:`base classes `__ dataset for optical flow. + + The dataset is expected to have the following structure: :: + + root + hd1k + hd1k_challenge + image_2 + hd1k_flow_gt + flow_occ + hd1k_input + image_2 + + Args: + root (string): Root directory of the HD1K Dataset. + split (string, optional): The dataset split, either "train" (default) or "test" + transforms (callable, optional): A function/transform that takes in + ``img1, img2, flow, valid`` and returns a transformed version. + """ + + _has_builtin_flow_mask = True + + def __init__(self, root, split="train", transforms=None): + super().__init__(root=root, transforms=transforms) + + verify_str_arg(split, "split", valid_values=("train", "test")) + + root = Path(root) / "hd1k" + if split == "train": + # There are 36 "sequences" and we don't want seq i to overlap with seq i + 1, so we need this for loop + for seq_idx in range(36): + flows = sorted(glob(str(root / "hd1k_flow_gt" / "flow_occ" / f"{seq_idx:06d}_*.png"))) + images = sorted(glob(str(root / "hd1k_input" / "image_2" / f"{seq_idx:06d}_*.png"))) + for i in range(len(flows) - 1): + self._flow_list += [flows[i]] + self._image_list += [[images[i], images[i + 1]]] + else: + images1 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*10.png"))) + images2 = sorted(glob(str(root / "hd1k_challenge" / "image_2" / "*11.png"))) + for image1, image2 in zip(images1, images2): + self._image_list += [[image1, image2]] + + if not self._image_list: + raise FileNotFoundError( + "Could not find the HD1K images. Please make sure the directory structure is correct." + ) + + def _read_flow(self, file_name): + return _read_16bits_png_with_flow_and_valid_mask(file_name) + + def __getitem__(self, index): + """Return example at given index. + + Args: + index(int): The index of the example to retrieve + + Returns: + tuple: If ``split="train"`` a 4-tuple with ``(img1, img2, flow, + valid)`` where ``valid`` is a numpy boolean mask of shape (H, W) + indicating which flow values are valid. The flow is a numpy array of + shape (2, H, W) and the images are PIL images. If `split="test"`, a + 4-tuple with ``(img1, img2, None, None)`` is returned. + """ + return super().__getitem__(index) + + def _read_flo(file_name): """Read .flo file in Middlebury format""" # Code adapted from: