@@ -57,67 +57,6 @@ def generic_segmentation_dataset_test(self, dataset, num_images=1):
57
57
58
58
59
59
class Tester (DatasetTestcase ):
60
def test_imagefolder(self):
    """Verify ``ImageFolder`` against the bundled ``imagefolder`` fake-data tree.

    The tree is copied into a temporary directory and loaded with a loader
    that returns the file path unchanged, so samples can be compared as
    ``(path, class_index)`` tuples. The checks run twice: once without a
    filter and once with an ``is_valid_file`` filter keeping only paths
    that contain ``'3'``.
    """
    # TODO: create the fake data on-the-fly
    test_dir = os.path.dirname(os.path.abspath(__file__))
    fakedata_dir = get_file_path_2(test_dir, 'assets', 'fakedata')

    def path_loader(path):
        # Return the path itself instead of decoding the image.
        return path

    def contains_three(path):
        return '3' in path

    with get_tmp_dir(src=os.path.join(fakedata_dir, 'imagefolder')) as root:
        expected_classes = sorted(['a', 'b'])
        files_by_class = {
            'a': [os.path.join(root, 'a', name) for name in ('a1.png', 'a2.png', 'a3.png')],
            'b': [os.path.join(root, 'b', name) for name in ('b1.png', 'b2.png', 'b3.png', 'b4.png')],
        }

        def expected_samples(ds, keep=None):
            # Build the sorted (path, class_index) pairs the dataset should expose.
            pairs = []
            for label in ('a', 'b'):
                label_idx = ds.class_to_idx[label]
                for path in files_by_class[label]:
                    if keep is None or keep(path):
                        pairs.append((path, label_idx))
            pairs.sort()
            return pairs

        dataset = torchvision.datasets.ImageFolder(root, loader=path_loader)

        # test if all classes are present
        self.assertEqual(expected_classes, sorted(dataset.classes))

        # test if combination of classes and class_to_index functions correctly
        for label in expected_classes:
            self.assertEqual(label, dataset.classes[dataset.class_to_idx[label]])

        # test if all images were detected correctly
        expected = expected_samples(dataset)
        self.assertEqual(expected, dataset.imgs)

        # test if the datasets outputs all images correctly
        loaded = sorted(dataset[i] for i in range(len(dataset)))
        self.assertEqual(expected, loaded)

        # redo all tests with specified valid image files
        dataset = torchvision.datasets.ImageFolder(
            root, loader=path_loader, is_valid_file=contains_three)
        self.assertEqual(expected_classes, sorted(dataset.classes))

        expected = expected_samples(dataset, keep=contains_three)
        self.assertEqual(expected, dataset.imgs)

        loaded = sorted(dataset[i] for i in range(len(dataset)))
        self.assertEqual(expected, loaded)
110
-
111
def test_imagefolder_empty(self):
    """Expect ``FileNotFoundError`` when ``ImageFolder`` is built on an empty directory.

    The construction is attempted twice on the same empty temporary root:
    first with only a loader, then additionally with an ``is_valid_file``
    callback that rejects every file.
    """
    def path_loader(path):
        return path

    # Extra keyword arguments for each of the two construction attempts.
    extra_kwargs_per_attempt = (
        {},
        {'is_valid_file': lambda path: False},
    )

    with get_tmp_dir() as root:
        for extra_kwargs in extra_kwargs_per_attempt:
            with self.assertRaises(FileNotFoundError):
                torchvision.datasets.ImageFolder(root, loader=path_loader, **extra_kwargs)
120
-
121
60
@mock .patch ('torchvision.datasets.SVHN._check_integrity' )
122
61
@unittest .skipIf (not HAS_SCIPY , "scipy unavailable" )
123
62
def test_svhn (self , mock_check ):
@@ -1673,5 +1612,95 @@ def test_num_examples_test50k(self):
1673
1612
self .assertEqual (len (dataset ), info ["num_examples" ] - 10000 )
1674
1613
1675
1614
1615
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for ``datasets.DatasetFolder`` backed by fake data generated per test.

    Every configuration selects a set of file extensions. For each selected
    extension one class folder is created and filled with image or video
    files, and the dataset is then expected to pick up exactly those files.
    """

    DATASET_CLASS = datasets.DatasetFolder

    # The dataset has no fixed return type since it is defined by the loader parameter. For testing, we use a loader
    # that simply returns the path as type 'str' instead of loading anything. See the 'dataset_args()' method.
    FEATURE_TYPES = (str, int)

    _IMAGE_EXTENSIONS = ("jpg", "png")
    _VIDEO_EXTENSIONS = ("avi", "mp4")
    _EXTENSIONS = (*_IMAGE_EXTENSIONS, *_VIDEO_EXTENSIONS)

    # DatasetFolder has two mutually exclusive parameters: 'extensions' and 'is_valid_file'. One of the two is
    # required. We only iterate over different 'extensions' here and handle the tests for 'is_valid_file' in the
    # 'test_is_valid_file()' method.
    DEFAULT_CONFIG = dict(extensions=_EXTENSIONS)
    ADDITIONAL_CONFIGS = (
        *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _IMAGE_EXTENSIONS]),
        dict(extensions=_IMAGE_EXTENSIONS),
        *datasets_utils.combinations_grid(extensions=[(ext,) for ext in _VIDEO_EXTENSIONS]),
        dict(extensions=_VIDEO_EXTENSIONS),
    )

    def dataset_args(self, tmpdir, config):
        """Return the positional dataset arguments: the root directory and an identity loader."""
        return tmpdir, lambda x: x

    def inject_fake_data(self, tmpdir, config):
        """Create one class folder per selected extension inside ``tmpdir``.

        The selected extensions are ``config["extensions"]`` when that value is
        truthy; otherwise they are derived from ``config["is_valid_file"]``.
        Extensions are paired with successive ASCII letters, which serve as
        class names, and each created folder receives 1 or 2 files.

        Returns:
            dict with ``num_examples`` (total number of files created) and
            ``classes`` (list of created class names in creation order).
        """
        extensions = config["extensions"] or self._is_valid_file_to_extensions(config["is_valid_file"])

        num_examples_total = 0
        classes = []
        for ext, cls in zip(self._EXTENSIONS, string.ascii_letters):
            if ext not in extensions:
                continue

            create_example_folder = (
                datasets_utils.create_image_folder
                if ext in self._IMAGE_EXTENSIONS
                else datasets_utils.create_video_folder
            )

            # 'high' is exclusive for torch.randint, so this yields 1 or 2.
            num_examples = torch.randint(1, 3, size=()).item()
            # Bind 'cls' and 'ext' as defaults so the callback does not depend on when the helper invokes it.
            create_example_folder(
                tmpdir, cls, lambda idx, cls=cls, ext=ext: self._file_name_fn(cls, ext, idx), num_examples
            )

            num_examples_total += num_examples
            classes.append(cls)

        return dict(num_examples=num_examples_total, classes=classes)

    def _file_name_fn(self, cls, ext, idx):
        """Return the file name ``<cls>_<idx>.<ext>`` for a fake example."""
        # No whitespace may leak into the name: a trailing space would hide the extension from suffix matching.
        return f"{cls}_{idx}.{ext}"

    def _is_valid_file_to_extensions(self, is_valid_file):
        """Return the set of known extensions for which ``is_valid_file("foo.<ext>")`` is truthy."""
        return {ext for ext in self._EXTENSIONS if is_valid_file(f"foo.{ext}")}

    @datasets_utils.test_all_configs
    def test_is_valid_file(self, config):
        """Check the dataset size when files are selected through ``is_valid_file`` instead of ``extensions``."""
        extensions = config.pop("extensions")
        # We need to explicitly pass extensions=None here or otherwise it would be filled by the value from the
        # DEFAULT_CONFIG.
        with self.create_dataset(
            config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions
        ) as (dataset, info):
            self.assertEqual(len(dataset), info["num_examples"])

    @datasets_utils.test_all_configs
    def test_classes(self, config):
        """Check that the dataset reports exactly the injected class names, in order."""
        with self.create_dataset(config) as (dataset, info):
            self.assertSequenceEqual(dataset.classes, info["classes"])
1683
+
1684
+
1685
class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase):
    """Tests for ``datasets.ImageFolder`` backed by fake PNG folders generated per test."""

    DATASET_CLASS = datasets.ImageFolder

    def inject_fake_data(self, tmpdir, config):
        """Create the class folders ``a`` and ``b`` inside ``tmpdir``, each holding 1 or 2 PNG files.

        Returns:
            dict with ``num_examples`` (total number of files created) and
            ``classes`` (tuple of the class names).
        """
        num_examples_total = 0
        classes = ("a", "b")
        for cls in classes:
            # 'high' is exclusive for torch.randint, so this yields 1 or 2.
            num_examples = torch.randint(1, 3, size=()).item()
            num_examples_total += num_examples

            # File names follow '<cls>_<idx>.png' without embedded whitespace. 'cls' is bound as a default so the
            # callback does not depend on when the helper invokes it.
            datasets_utils.create_image_folder(
                tmpdir, cls, lambda idx, cls=cls: f"{cls}_{idx}.png", num_examples
            )

        return dict(num_examples=num_examples_total, classes=classes)

    @datasets_utils.test_all_configs
    def test_classes(self, config):
        """Check that the dataset reports exactly the injected class names, in order."""
        with self.create_dataset(config) as (dataset, info):
            self.assertSequenceEqual(dataset.classes, info["classes"])
1703
+
1704
+
1676
1705
if __name__ == "__main__" :
1677
1706
unittest .main ()
0 commit comments