
Commit 315f1a2

prabhat00155 authored and facebook-github-bot committed
[fbsync] Add Kitti and Sintel datasets for optical flow (#4845)
Reviewed By: kazhang
Differential Revision: D32216685
fbshipit-source-id: ec74c2a573eace36bd4a0cf9913ea1dc77fcf260
1 parent bddc464 commit 315f1a2

File tree

5 files changed: 373 additions, 1 deletion

docs/source/datasets.rst

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
     INaturalist
     Kinetics400
     Kitti
+    KittiFlow
     KMNIST
     LFWPeople
     LFWPairs
@@ -60,6 +61,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
     SBDataset
     SBU
     SEMEION
+    Sintel
     STL10
     SVHN
     UCF101
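
For orientation, the two new entries are the optical flow datasets added by this commit. A minimal usage sketch, based on the test cases later in this diff; the exact constructor signatures live in torchvision/datasets/_optical_flow.py, which is not shown on this page, so the root/split/pass_name arguments below are assumptions modeled on the other torchvision datasets:

from torchvision import datasets

# Sketch based on the tests below: Sintel yields (img1, img2, flow),
# where flow is None for the "test" split; the accepted values are
# split in ("train", "test") and pass_name in ("clean", "final").
sintel = datasets.Sintel(root="data", split="train", pass_name="clean")
img1, img2, flow = sintel[0]

# KittiFlow yields (img1, img2, flow, valid); flow and the validity
# mask are None for the "test" split.
kitti = datasets.KittiFlow(root="data", split="train")
img1, img2, flow, valid = kitti[0]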

test/datasets_utils.py

Lines changed: 7 additions & 1 deletion
@@ -203,6 +203,7 @@ class DatasetTestCase(unittest.TestCase):
         ``transforms``, or ``download``.
     - REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
         available, the tests are skipped.
+    - EXTRA_PATCHES (set): Additional patches to add for each test, e.g. to mock a specific function
 
     Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
     The fake data should resemble the original data as close as necessary, while containing only few examples. During
@@ -254,6 +255,8 @@ def test_baz(self):
     ADDITIONAL_CONFIGS = None
     REQUIRED_PACKAGES = None
 
+    EXTRA_PATCHES = None
+
     # These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
     _TRANSFORM_KWARGS = {
         "transform",
@@ -379,14 +382,17 @@ def create_dataset(
         if patch_checks:
             patchers.update(self._patch_checks())
 
+        if self.EXTRA_PATCHES is not None:
+            patchers.update(self.EXTRA_PATCHES)
+
         with get_tmp_dir() as tmpdir:
             args = self.dataset_args(tmpdir, complete_config)
             info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
 
             with self._maybe_apply_patches(patchers), disable_console_output():
                 dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)
 
-            yield dataset, info
+                yield dataset, info
 
     @classmethod
     def setUpClass(cls):
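
Two things are worth noting in this hunk: EXTRA_PATCHES entries are plain unittest.mock.patch objects that get merged into the patcher set, and the yield is moved inside the with block so those patches stay active while the test body consumes the dataset. A minimal subclass sketch of the mechanism (MyDataset and _expensive_io are hypothetical names, not part of torchvision):

import unittest.mock

import datasets_utils
from torchvision import datasets


class MyDatasetTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.MyDataset  # hypothetical dataset class
    # Each entry is a unittest.mock.patch object; create_dataset() adds it to
    # its patcher set and keeps it active while the test uses the dataset.
    EXTRA_PATCHES = {
        unittest.mock.patch(
            "torchvision.datasets.MyDataset._expensive_io", return_value=None
        ),
    }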

test/test_datasets.py

Lines changed: 127 additions & 0 deletions
@@ -1871,5 +1871,132 @@ def _inject_pairs(self, root, num_pairs, same):
         datasets_utils.create_image_folder(root, name2, lambda _: f"{name2}_{no2:04d}.jpg", 1, 250)
 
 
+class SintelTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.Sintel
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final"))
+    # We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files,
+    # which is something we want to avoid.
+    _FAKE_FLOW = "Fake Flow"
+    EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)}
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None)))
+
+    def inject_fake_data(self, tmpdir, config):
+        root = pathlib.Path(tmpdir) / "Sintel"
+
+        num_images_per_scene = 3 if config["split"] == "train" else 4
+        num_scenes = 2
+
+        for split_dir in ("training", "test"):
+            for pass_name in ("clean", "final"):
+                image_root = root / split_dir / pass_name
+
+                for scene_id in range(num_scenes):
+                    scene_dir = image_root / f"scene_{scene_id}"
+                    datasets_utils.create_image_folder(
+                        image_root,
+                        name=str(scene_dir),
+                        file_name_fn=lambda image_idx: f"frame_000{image_idx}.png",
+                        num_examples=num_images_per_scene,
+                    )
+
+        # For the ground truth flow value we just create empty files so that they're properly discovered,
+        # see comment above about EXTRA_PATCHES
+        flow_root = root / "training" / "flow"
+        for scene_id in range(num_scenes):
+            scene_dir = flow_root / f"scene_{scene_id}"
+            os.makedirs(scene_dir)
+            for i in range(num_images_per_scene - 1):
+                open(str(scene_dir / f"frame_000{i}.flo"), "a").close()
+
+        # With e.g. num_images_per_scene = 3, a single scene has 3 images:
+        # frame_0000, frame_0001 and frame_0002.
+        # They will be consecutively paired as (frame_0000, frame_0001), (frame_0001, frame_0002),
+        # that is 3 - 1 = 2 examples. Hence the formula below.
+        num_examples = (num_images_per_scene - 1) * num_scenes
+        return num_examples
+
+    def test_flow(self):
+        # Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images
+        with self.create_dataset(split="train") as (dataset, _):
+            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
+            for _, _, flow in dataset:
+                assert flow == self._FAKE_FLOW
+
+        # Make sure flow is always None for test split
+        with self.create_dataset(split="test") as (dataset, _):
+            assert dataset._image_list and not dataset._flow_list
+            for _, _, flow in dataset:
+                assert flow is None
+
+    def test_bad_input(self):
+        with pytest.raises(ValueError, match="split must be either"):
+            with self.create_dataset(split="bad"):
+                pass
+
+        with pytest.raises(ValueError, match="pass_name must be either"):
+            with self.create_dataset(pass_name="bad"):
+                pass
+
+
+class KittiFlowTestCase(datasets_utils.ImageDatasetTestCase):
+    DATASET_CLASS = datasets.KittiFlow
+    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))
+    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)), (np.ndarray, type(None)))
+
+    def inject_fake_data(self, tmpdir, config):
+        root = pathlib.Path(tmpdir) / "Kitti"
+
+        num_examples = 2 if config["split"] == "train" else 3
+        for split_dir in ("training", "testing"):
+
+            datasets_utils.create_image_folder(
+                root / split_dir,
+                name="image_2",
+                file_name_fn=lambda image_idx: f"{image_idx}_10.png",
+                num_examples=num_examples,
+            )
+            datasets_utils.create_image_folder(
+                root / split_dir,
+                name="image_2",
+                file_name_fn=lambda image_idx: f"{image_idx}_11.png",
+                num_examples=num_examples,
+            )
+
+        # For Kitti the ground truth flows are encoded as 16-bit pngs.
+        # create_image_folder() will actually create 8-bit pngs, but it doesn't
+        # matter much: the flow reader will still be able to read the files, it
+        # will just be garbage flow values - but we don't care about that here.
+        datasets_utils.create_image_folder(
+            root / "training",
+            name="flow_occ",
+            file_name_fn=lambda image_idx: f"{image_idx}_10.png",
+            num_examples=num_examples,
+        )
+
+        return num_examples
+
+    def test_flow_and_valid(self):
+        # Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images.
+        # Also assert flow and valid are of the expected shape.
+        with self.create_dataset(split="train") as (dataset, _):
+            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
+            for _, _, flow, valid in dataset:
+                two, h, w = flow.shape
+                assert two == 2
+                assert valid.shape == (h, w)
+
+        # Make sure flow and valid are always None for test split
+        with self.create_dataset(split="test") as (dataset, _):
+            assert dataset._image_list and not dataset._flow_list
+            for _, _, flow, valid in dataset:
+                assert flow is None
+                assert valid is None
+
+    def test_bad_input(self):
+        with pytest.raises(ValueError, match="split must be either"):
+            with self.create_dataset(split="bad"):
+                pass
+
+
 if __name__ == "__main__":
     unittest.main()
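
The example-count formula in SintelTestCase.inject_fake_data follows from consecutive pairing; a standalone illustration of that arithmetic (the names here are illustrative only, not dataset code):

# Frames within one scene are paired consecutively, so N images yield N - 1 pairs.
frames = [f"frame_{i:04d}" for i in range(3)]
pairs = list(zip(frames, frames[1:]))
assert pairs == [("frame_0000", "frame_0001"), ("frame_0001", "frame_0002")]
assert len(pairs) == len(frames) - 1  # hence (num_images_per_scene - 1) per scene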

torchvision/datasets/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,4 @@
+from ._optical_flow import KittiFlow, Sintel
 from .caltech import Caltech101, Caltech256
 from .celeba import CelebA
 from .cifar import CIFAR10, CIFAR100
@@ -71,4 +72,6 @@
     "INaturalist",
     "LFWPeople",
     "LFWPairs",
+    "KittiFlow",
+    "Sintel",
 )
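
With this re-export in place, both classes are reachable from the public torchvision.datasets namespace. A quick check, assuming the tuple being extended in the second hunk is the module's __all__ (its name is not visible in this diff) and a torchvision build that includes this commit:

import torchvision.datasets as datasets

# Both names should now be part of the public API surface.
assert "KittiFlow" in datasets.__all__ and "Sintel" in datasets.__all__
print(datasets.KittiFlow, datasets.Sintel)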
