diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index d121bad7a19..9eb849bbe34 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -182,7 +182,7 @@ def __init__( ) -> None: super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) - classes, class_to_idx = self._find_classes(self.root) + classes, class_to_idx = self.find_classes(self.root) samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) self.loader = loader @@ -202,8 +202,12 @@ def make_dataset( ) -> List[Tuple[str, int]]: return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) - @staticmethod - def _find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: + def find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: + """Same as :func:`find_classes`. + + This method can be overridden to only consider + a subset of classes, or to adapt to a different dataset directory structure. + """ return find_classes(dir) def __getitem__(self, index: int) -> Tuple[Any, Any]: