Skip to content

Commit 20a771e

Browse files
authored
add tests for (Dataset|Image)Folder (#3477)
* add tests for (Dataset|Image)Folder * lint * remove old tests * cleanup * more cleanup * adapt tests * fix make_dataset * remove powerset * readd import
1 parent 7cc941f commit 20a771e

File tree

2 files changed

+91
-62
lines changed

2 files changed

+91
-62
lines changed

test/test_datasets.py

Lines changed: 90 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -57,67 +57,6 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1):
5757

5858

5959
class Tester(DatasetTestcase):
60-
def test_imagefolder(self):
    """Check ImageFolder's classes, index mapping and samples on the 'imagefolder' fake data.

    The sample checks run twice: once on an unfiltered dataset and once on a
    dataset built with an ``is_valid_file`` predicate that keeps only paths
    containing the character '3'.
    """
    # TODO: create the fake data on-the-fly
    here = os.path.dirname(os.path.abspath(__file__))
    fakedata_dir = get_file_path_2(here, 'assets', 'fakedata')

    with get_tmp_dir(src=os.path.join(fakedata_dir, 'imagefolder')) as root:
        classes = sorted(['a', 'b'])
        files_a = [os.path.join(root, 'a', name) for name in ('a1.png', 'a2.png', 'a3.png')]
        files_b = [os.path.join(root, 'b', name) for name in ('b1.png', 'b2.png', 'b3.png', 'b4.png')]

        def check_samples(dataset, keep):
            # Build the expected (path, class index) pairs for the files accepted by 'keep'.
            idx_a = dataset.class_to_idx['a']
            idx_b = dataset.class_to_idx['b']
            expected = [(path, idx_a) for path in files_a if keep(path)]
            expected += [(path, idx_b) for path in files_b if keep(path)]
            expected.sort()

            # All images must have been detected correctly.
            self.assertEqual(expected, dataset.imgs)

            # The loader returns the path unchanged, so indexing the dataset
            # must yield the very same (path, class index) pairs.
            loaded = [dataset[i] for i in range(len(dataset))]
            loaded.sort()
            self.assertEqual(expected, loaded)

        dataset = torchvision.datasets.ImageFolder(root, loader=lambda x: x)

        # All classes must be present.
        self.assertEqual(classes, sorted(dataset.classes))

        # 'classes' and 'class_to_idx' must be consistent with each other.
        for cls in classes:
            self.assertEqual(cls, dataset.classes[dataset.class_to_idx[cls]])

        check_samples(dataset, keep=lambda path: True)

        # Redo the checks with only the explicitly valid image files.
        dataset = torchvision.datasets.ImageFolder(
            root, loader=lambda x: x, is_valid_file=lambda x: '3' in x)
        self.assertEqual(classes, sorted(dataset.classes))
        check_samples(dataset, keep=lambda path: '3' in path)
110-
111-
def test_imagefolder_empty(self):
    """ImageFolder must raise FileNotFoundError for an empty root directory.

    The check is done both without a filter and with an ``is_valid_file``
    predicate that rejects every file.
    """
    extra_kwargs_variants = (
        {},
        {"is_valid_file": lambda x: False},
    )
    with get_tmp_dir() as root:
        for extra_kwargs in extra_kwargs_variants:
            with self.assertRaises(FileNotFoundError):
                torchvision.datasets.ImageFolder(root, loader=lambda x: x, **extra_kwargs)
120-
12160
@mock.patch('torchvision.datasets.SVHN._check_integrity')
12261
@unittest.skipIf(not HAS_SCIPY, "scipy unavailable")
12362
def test_svhn(self, mock_check):
@@ -1673,5 +1612,95 @@ def test_num_examples_test50k(self):
16731612
self.assertEqual(len(dataset), info["num_examples"] - 10000)
16741613

16751614

1615+
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.DatasetFolder`` driven by generated image and video folders.

    One class folder is created per requested file extension; the class names
    are taken from ``string.ascii_letters`` in the order of ``_EXTENSIONS``.
    """

    DATASET_CLASS = datasets.DatasetFolder

    # The dataset has no fixed return type because it is determined by the 'loader' parameter. For testing, the
    # loader supplied by 'dataset_args()' returns the path (a 'str') unchanged instead of loading anything.
    FEATURE_TYPES = (str, int)

    _IMAGE_EXTENSIONS = ("jpg", "png")
    _VIDEO_EXTENSIONS = ("avi", "mp4")
    _EXTENSIONS = _IMAGE_EXTENSIONS + _VIDEO_EXTENSIONS

    # DatasetFolder takes two mutually exclusive parameters, 'extensions' and 'is_valid_file'; one of them has to be
    # given. The configs below only vary 'extensions'. 'is_valid_file' is covered by 'test_is_valid_file()'.
    DEFAULT_CONFIG = {"extensions": _EXTENSIONS}
    ADDITIONAL_CONFIGS = (
        *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]),
        {"extensions": _IMAGE_EXTENSIONS},
        *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]),
        {"extensions": _VIDEO_EXTENSIONS},
    )

    def dataset_args(self, tmpdir, config):
        """Return the positional dataset arguments: the root directory and an identity loader."""

        def identity_loader(path):
            return path

        return tmpdir, identity_loader

    def inject_fake_data(self, tmpdir, config):
        """Create one class folder per requested extension and report what was created.

        Returns:
            dict with the total number of created examples ('num_examples') and the
            list of created class names ('classes').
        """
        requested = config["extensions"]
        if not requested:
            # No explicit extensions: derive them from the 'is_valid_file' predicate instead.
            requested = self._is_valid_file_to_extensions(config["is_valid_file"])

        created_classes = []
        total = 0
        for ext, cls in zip(self._EXTENSIONS, string.ascii_letters):
            if ext in requested:
                if ext in self._IMAGE_EXTENSIONS:
                    make_folder = datasets_utils.create_image_folder
                else:
                    make_folder = datasets_utils.create_video_folder

                count = torch.randint(1, 3, size=()).item()
                make_folder(tmpdir, cls, lambda idx: self._file_name_fn(cls, ext, idx), count)

                total += count
                created_classes.append(cls)

        return dict(num_examples=total, classes=created_classes)

    def _file_name_fn(self, cls, ext, idx):
        """Build the file name of example 'idx' of class 'cls' with extension 'ext'."""
        return f"{cls}_{idx}.{ext}"

    def _is_valid_file_to_extensions(self, is_valid_file):
        """Return the set of known extensions for which 'is_valid_file' accepts a probe file name."""
        accepted = set()
        for ext in self._EXTENSIONS:
            if is_valid_file(f"foo.{ext}"):
                accepted.add(ext)
        return accepted

    @datasets_utils.test_all_configs
    def test_is_valid_file(self, config):
        """The dataset length must match the injected examples when filtering through 'is_valid_file'."""
        extensions = config.pop("extensions")

        def has_requested_extension(file):
            return pathlib.Path(file).suffix[1:] in extensions

        # 'extensions=None' has to be passed explicitly, otherwise it would be filled with the value from the
        # DEFAULT_CONFIG.
        with self.create_dataset(config, extensions=None, is_valid_file=has_requested_extension) as (dataset, info):
            self.assertEqual(len(dataset), info["num_examples"])

    @datasets_utils.test_all_configs
    def test_classes(self, config):
        """The classes found by the dataset must equal the classes that were injected."""
        with self.create_dataset(config) as (dataset, info):
            self.assertSequenceEqual(dataset.classes, info["classes"])
1683+
1684+
1685+
class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
    """Test case for ``datasets.ImageFolder`` driven by two generated class folders, 'a' and 'b'."""

    DATASET_CLASS = datasets.ImageFolder

    def inject_fake_data(self, tmpdir, config):
        """Create a folder with one or two PNG examples for each class and report what was created.

        Returns:
            dict with the total number of created examples ('num_examples') and the
            tuple of class names ('classes').
        """
        examples_per_class = {}
        for cls in ("a", "b"):
            # torch.randint excludes the upper bound, so this draws 1 or 2.
            count = torch.randint(1, 3, size=()).item()
            examples_per_class[cls] = count

            datasets_utils.create_image_folder(tmpdir, cls, lambda idx: f"{cls}_{idx}.png", count)

        return dict(num_examples=sum(examples_per_class.values()), classes=tuple(examples_per_class))

    @datasets_utils.test_all_configs
    def test_classes(self, config):
        """The classes found by the dataset must equal the classes that were injected."""
        with self.create_dataset(config) as (dataset, info):
            self.assertSequenceEqual(dataset.classes, info["classes"])
1703+
1704+
16761705
if __name__ == "__main__":
16771706
unittest.main()

torchvision/datasets/folder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def is_valid_file(x: str) -> bool:
129129
if target_class not in available_classes:
130130
available_classes.add(target_class)
131131

132-
empty_classes = available_classes - set(class_to_idx.keys())
132+
empty_classes = set(class_to_idx.keys()) - available_classes
133133
if empty_classes:
134134
msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
135135
if extensions is not None:

0 commit comments

Comments
 (0)