From 237a707373471cd74131b8c4ec4b5e0ab4df946d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 5 Jan 2022 15:02:30 +0000 Subject: [PATCH 01/12] Change default of download for Food101 and DTD --- torchvision/datasets/dtd.py | 4 ++-- torchvision/datasets/food101.py | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index ceacc64eedb..1970cd6dce4 100644 --- a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -23,7 +23,7 @@ class DTD(VisionDataset): download (bool, optional): If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not - downloaded again. + downloaded again. Default is False. transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. @@ -37,7 +37,7 @@ def __init__( root: str, split: str = "train", partition: int = 1, - download: bool = True, + download: bool = False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ) -> None: diff --git a/torchvision/datasets/food101.py b/torchvision/datasets/food101.py index cffe0c50a06..fa194d56468 100644 --- a/torchvision/datasets/food101.py +++ b/torchvision/datasets/food101.py @@ -21,6 +21,9 @@ class Food101(VisionDataset): Args: root (string): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. Default is False. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. 
target_transform (callable, optional): A function/transform that takes in the target and transforms it. @@ -33,7 +36,7 @@ def __init__( self, root: str, split: str = "train", - download: bool = True, + download: bool = False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ) -> None: From 87695d4114b92113a1997edc7b6d7289743fe775 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Jan 2022 10:58:15 +0000 Subject: [PATCH 02/12] Set download default to False and put it at the end --- torchvision/datasets/clevr.py | 2 +- torchvision/datasets/country211.py | 2 +- torchvision/datasets/dtd.py | 8 ++++---- torchvision/datasets/eurosat.py | 13 +++++++------ torchvision/datasets/fgvc_aircraft.py | 8 ++++---- torchvision/datasets/flowers102.py | 8 ++++---- torchvision/datasets/food101.py | 8 ++++---- torchvision/datasets/oxford_iiit_pet.py | 2 +- torchvision/datasets/pcam.py | 2 +- torchvision/datasets/sun397.py | 8 ++++---- 10 files changed, 31 insertions(+), 30 deletions(-) diff --git a/torchvision/datasets/clevr.py b/torchvision/datasets/clevr.py index 7ba5ca6cc47..112765a6b5d 100644 --- a/torchvision/datasets/clevr.py +++ b/torchvision/datasets/clevr.py @@ -34,7 +34,7 @@ def __init__( split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, - download: bool = True, + download: bool = False, ) -> None: self._split = verify_str_arg(split, "split", ("train", "val", "test")) super().__init__(root, transform=transform, target_transform=target_transform) diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index 20b69bc729e..b5c650cb276 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -32,7 +32,7 @@ def __init__( split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, - download: bool = True, + download: bool = False, ) -> None: self._split = verify_str_arg(split, 
"split", ("train", "valid", "test")) diff --git a/torchvision/datasets/dtd.py b/torchvision/datasets/dtd.py index 1970cd6dce4..deb27312573 100644 --- a/torchvision/datasets/dtd.py +++ b/torchvision/datasets/dtd.py @@ -21,12 +21,12 @@ class DTD(VisionDataset): The partition only changes which split each image belongs to. Thus, regardless of the selected partition, combining all splits will result in all images. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. Default is False. transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. Default is False. """ _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz" @@ -37,9 +37,9 @@ def __init__( root: str, split: str = "train", partition: int = 1, - download: bool = False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: self._split = verify_str_arg(split, "split", ("train", "val", "test")) if not isinstance(partition, int) and not (1 <= partition <= 10): diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index d7876b7afd5..4096d0e2c66 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -1,5 +1,5 @@ import os -from typing import Any +from typing import Callable, Optional from .folder import ImageFolder from .utils import download_and_extract_archive @@ -10,13 +10,13 @@ class EuroSAT(ImageFolder): Args: root (string): Root directory of dataset where ``root/eurosat`` exists. 
- download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. Default is False. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. Default is False. """ url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip" @@ -25,8 +25,9 @@ class EuroSAT(ImageFolder): def __init__( self, root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, download: bool = False, - **kwargs: Any, ) -> None: self.root = os.path.expanduser(root) self._base_folder = os.path.join(self.root, "eurosat") @@ -38,7 +39,7 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - super().__init__(self._data_folder, **kwargs) + super().__init__(self._data_folder, transform=transform, target_transform=target_transform) self.root = os.path.expanduser(root) def __len__(self) -> int: diff --git a/torchvision/datasets/fgvc_aircraft.py b/torchvision/datasets/fgvc_aircraft.py index 687d44fb7f0..d0bbf586639 100644 --- a/torchvision/datasets/fgvc_aircraft.py +++ b/torchvision/datasets/fgvc_aircraft.py @@ -26,15 +26,15 @@ class FGVCAircraft(VisionDataset): root (string): Root directory of the FGVC Aircraft dataset. split (string, optional): The dataset split, supports ``train``, ``val``, ``trainval`` and ``test``. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. 
annotation_level (str, optional): The annotation level, supports ``variant``, ``family`` and ``manufacturer``. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. """ _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz" @@ -43,10 +43,10 @@ def __init__( self, root: str, split: str = "trainval", - download: bool = False, annotation_level: str = "variant", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test")) diff --git a/torchvision/datasets/flowers102.py b/torchvision/datasets/flowers102.py index 55347ffa550..8f4810e62e1 100644 --- a/torchvision/datasets/flowers102.py +++ b/torchvision/datasets/flowers102.py @@ -24,12 +24,12 @@ class Flowers102(VisionDataset): Args: root (string): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. - download (bool, optional): If true, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
+ download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. """ _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/" @@ -44,9 +44,9 @@ def __init__( self, root: str, split: str = "train", - download: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "val", "test")) diff --git a/torchvision/datasets/food101.py b/torchvision/datasets/food101.py index fa194d56468..1bb4d8094d5 100644 --- a/torchvision/datasets/food101.py +++ b/torchvision/datasets/food101.py @@ -21,12 +21,12 @@ class Food101(VisionDataset): Args: root (string): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. Default is False. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. Default is False. 
""" _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz" @@ -36,9 +36,9 @@ def __init__( self, root: str, split: str = "train", - download: bool = False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "test")) diff --git a/torchvision/datasets/oxford_iiit_pet.py b/torchvision/datasets/oxford_iiit_pet.py index f7f77b997c2..733aa78256b 100644 --- a/torchvision/datasets/oxford_iiit_pet.py +++ b/torchvision/datasets/oxford_iiit_pet.py @@ -45,7 +45,7 @@ def __init__( transforms: Optional[Callable] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, - download: bool = True, + download: bool = False, ): self._split = verify_str_arg(split, "split", ("trainval", "test")) if isinstance(target_types, str): diff --git a/torchvision/datasets/pcam.py b/torchvision/datasets/pcam.py index f9b9b6817bf..7238931d1f3 100644 --- a/torchvision/datasets/pcam.py +++ b/torchvision/datasets/pcam.py @@ -72,7 +72,7 @@ def __init__( split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, - download: bool = True, + download: bool = False, ): try: import h5py # type: ignore[import] diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index da34351771f..2814ca80232 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -19,12 +19,12 @@ class SUN397(VisionDataset): split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. partition (int, optional): A valid partition can be an integer from 1 to 10 or None, for the entire dataset. - download (bool, optional): If true, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. 
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. """ _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz" @@ -37,9 +37,9 @@ def __init__( root: str, split: str = "train", partition: Optional[int] = 1, - download: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self.split = verify_str_arg(split, "split", ("train", "test")) From 1e6e37d860b38eebd89ca74de9b2ff104e401e61 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Jan 2022 11:02:41 +0000 Subject: [PATCH 03/12] Keep stuff private --- torchvision/datasets/eurosat.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/datasets/eurosat.py b/torchvision/datasets/eurosat.py index 4096d0e2c66..bec6df5312d 100644 --- a/torchvision/datasets/eurosat.py +++ b/torchvision/datasets/eurosat.py @@ -19,9 +19,6 @@ class EuroSAT(ImageFolder): downloaded again. Default is False. 
""" - url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip" - md5 = "c8fa014336c82ac7804f0398fcb19387" - def __init__( self, root: str, @@ -54,4 +51,8 @@ def download(self) -> None: return os.makedirs(self._base_folder, exist_ok=True) - download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5) + download_and_extract_archive( + "https://madm.dfki.de/files/sentinel/EuroSAT.zip", + download_root=self._base_folder, + md5="c8fa014336c82ac7804f0398fcb19387", + ) From 474546fe8e82ebcc5b7387893244766e1205bfe1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Jan 2022 11:23:44 +0000 Subject: [PATCH 04/12] GTSRB: train -> split. Also use pathlib --- test/test_datasets.py | 10 +++--- torchvision/datasets/gtsrb.py | 61 ++++++++++++++++------------------- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index e306930aaf2..02876364651 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2397,17 +2397,17 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.GTSRB FEATURE_TYPES = (PIL.Image.Image, int) - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False)) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test")) def inject_fake_data(self, tmpdir: str, config): - root_folder = os.path.join(tmpdir, "GTSRB") + root_folder = os.path.join(tmpdir, "gtsrb") os.makedirs(root_folder, exist_ok=True) # Train data - train_folder = os.path.join(root_folder, "Training") + train_folder = os.path.join(root_folder, "GTSRB", "Training") os.makedirs(train_folder, exist_ok=True) - num_examples = 3 + num_examples = 3 if config["split"] == "train" else 4 classes = ("00000", "00042", "00012") for class_idx in classes: datasets_utils.create_image_folder( @@ -2419,7 +2419,7 @@ def inject_fake_data(self, tmpdir: str, config): total_number_of_examples = num_examples * len(classes) # Test data - test_folder = 
os.path.join(root_folder, "Final_Test", "Images") + test_folder = os.path.join(root_folder, "GTSRB", "Final_Test", "Images") os.makedirs(test_folder, exist_ok=True) with open(os.path.join(root_folder, "GT-final_test.csv"), "w") as csv_file: diff --git a/torchvision/datasets/gtsrb.py b/torchvision/datasets/gtsrb.py index d970a0b472d..8a8ee3bdd77 100644 --- a/torchvision/datasets/gtsrb.py +++ b/torchvision/datasets/gtsrb.py @@ -1,11 +1,11 @@ import csv -import os +import pathlib from typing import Any, Callable, Optional, Tuple import PIL from .folder import make_dataset -from .utils import download_and_extract_archive +from .utils import download_and_extract_archive, verify_str_arg from .vision import VisionDataset @@ -14,8 +14,7 @@ class GTSRB(VisionDataset): Args: root (string): Root directory of the dataset. - train (bool, optional): If True, creates dataset from training set, otherwise - creates from test set. + split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. @@ -24,23 +23,10 @@ class GTSRB(VisionDataset): downloaded again. 
""" - # Ground Truth for the test set - _gt_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip" - _gt_csv = "GT-final_test.csv" - _gt_md5 = "fe31e9c9270bbcd7b84b7f21a9d9d9e5" - - # URLs for the test and train set - _urls = ( - "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip", - "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip", - ) - - _md5s = ("c7e4e6327067d32654124b0fe9e82185", "513f3c79a4c5141765e10e952eaa2478") - def __init__( self, root: str, - train: bool = True, + split: str = "train", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, @@ -48,12 +34,11 @@ def __init__( super().__init__(root, transform=transform, target_transform=target_transform) - self.root = os.path.expanduser(root) - - self.train = train - - self._base_folder = os.path.join(self.root, type(self).__name__) - self._target_folder = os.path.join(self._base_folder, "Training" if self.train else "Final_Test/Images") + self._split = verify_str_arg(split, "split", ("train", "test")) + self._base_folder = pathlib.Path(root) / "gtsrb" + self._target_folder = ( + self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images") + ) if download: self.download() @@ -61,12 +46,12 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. 
You can use download=True to download it") - if train: + if self._split == "train": samples = make_dataset(self._target_folder, extensions=(".ppm",)) else: - with open(os.path.join(self._base_folder, self._gt_csv)) as csv_file: + with open(self._base_folder / "GT-final_test.csv") as csv_file: samples = [ - (os.path.join(self._target_folder, row["Filename"]), int(row["ClassId"])) + (self._target_folder / row["Filename"], int(row["ClassId"])) for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True) ] @@ -91,16 +76,26 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: return sample, target def _check_exists(self) -> bool: - return os.path.exists(self._target_folder) and os.path.isdir(self._target_folder) + return self._target_folder.is_dir() def download(self) -> None: if self._check_exists(): return - download_and_extract_archive(self._urls[self.train], download_root=self.root, md5=self._md5s[self.train]) - - if not self.train: - # Download Ground Truth for the test set + if self._split == "train": + download_and_extract_archive( + "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip", + download_root=str(self._base_folder), + md5="513f3c79a4c5141765e10e952eaa2478", + ) + else: + download_and_extract_archive( + "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip", + download_root=str(self._base_folder), + md5="c7e4e6327067d32654124b0fe9e82185", + ) download_and_extract_archive( - self._gt_url, download_root=self.root, extract_root=self._base_folder, md5=self._gt_md5 + "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip", + download_root=str(self._base_folder), + md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5", ) From a38a18b8c15486f8ae01960bf6bb7cd5f1901a5b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Jan 2022 11:42:47 +0000 Subject: [PATCH 05/12] mypy --- mypy.ini | 4 ++++ torchvision/datasets/pcam.py | 2 +- 2 
files changed, 5 insertions(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index a6000f8a9d5..931665240f3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -117,3 +117,7 @@ ignore_missing_imports = True [mypy-torchdata.*] ignore_missing_imports = True + +[mypy-h5py.*] + +ignore_missing_imports = True diff --git a/torchvision/datasets/pcam.py b/torchvision/datasets/pcam.py index 7238931d1f3..4f124674961 100644 --- a/torchvision/datasets/pcam.py +++ b/torchvision/datasets/pcam.py @@ -75,7 +75,7 @@ def __init__( download: bool = False, ): try: - import h5py # type: ignore[import] + import h5py self.h5py = h5py except ImportError: From d58ef16d9ac4050955628b1db22f67113c8d824b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Jan 2022 14:12:05 +0000 Subject: [PATCH 06/12] Remove split and partition for SUN397 --- test/test_datasets.py | 18 +----------------- torchvision/datasets/sun397.py | 30 ++++-------------------------- 2 files changed, 5 insertions(+), 43 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 02876364651..ca1579429be 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2281,11 +2281,6 @@ def inject_fake_data(self, tmpdir: str, config): class SUN397TestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.SUN397 - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( - split=("train", "test"), - partition=(1, 10, None), - ) - def inject_fake_data(self, tmpdir: str, config): data_dir = pathlib.Path(tmpdir) / "SUN397" data_dir.mkdir() @@ -2308,18 +2303,7 @@ def inject_fake_data(self, tmpdir: str, config): with open(data_dir / "ClassName.txt", "w") as file: file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes)) - if config["partition"] is not None: - num_samples = max(len(im_paths) // (2 if config["split"] == "train" else 3), 1) - - with open(data_dir / f"{config['split'].title()}ing_{config['partition']:02d}.txt", "w") as file: - file.writelines( - "\n".join( - 
f"/{f_path.relative_to(data_dir).as_posix()}" - for f_path in random.choices(im_paths, k=num_samples) - ) - ) - else: - num_samples = len(im_paths) + num_samples = len(im_paths) return num_samples diff --git a/torchvision/datasets/sun397.py b/torchvision/datasets/sun397.py index 2814ca80232..cc3457fb16f 100644 --- a/torchvision/datasets/sun397.py +++ b/torchvision/datasets/sun397.py @@ -3,7 +3,7 @@ import PIL.Image -from .utils import verify_str_arg, download_and_extract_archive +from .utils import download_and_extract_archive from .vision import VisionDataset @@ -11,14 +11,10 @@ class SUN397(VisionDataset): """`The SUN397 Data Set `_. The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of - 397 categories with 108'754 images. The dataset also provides 10 partitions for training - and testing, with each partition consisting of 50 images per class. + 397 categories with 108'754 images. Args: root (string): Root directory of the dataset. - split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. - partition (int, optional): A valid partition can be an integer from 1 to 10 or None, - for the entire dataset. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
@@ -29,27 +25,17 @@ class SUN397(VisionDataset): _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz" _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a" - _PARTITIONS_URL = "https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip" - _PARTITIONS_MD5 = "29a205c0a0129d21f36cbecfefe81881" def __init__( self, root: str, - split: str = "train", - partition: Optional[int] = 1, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) - self.split = verify_str_arg(split, "split", ("train", "test")) - self.partition = partition self._data_dir = Path(self.root) / "SUN397" - if self.partition is not None: - if self.partition < 0 or self.partition > 10: - raise RuntimeError(f"The partition parameter should be an int in [1, 10] or None, got {partition}.") - if download: self._download() @@ -60,11 +46,7 @@ def __init__( self.classes = [c[3:].strip() for c in f] self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) - if self.partition is not None: - with open(self._data_dir / f"{self.split.title()}ing_{self.partition:02d}.txt", "r") as f: - self._image_files = [self._data_dir.joinpath(*line.strip()[1:].split("/")) for line in f] - else: - self._image_files = list(self._data_dir.rglob("sun_*.jpg")) + self._image_files = list(self._data_dir.rglob("sun_*.jpg")) self._labels = [ self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files @@ -86,13 +68,9 @@ def __getitem__(self, idx) -> Tuple[Any, Any]: return image, label def _check_exists(self) -> bool: - return self._data_dir.exists() and self._data_dir.is_dir() - - def extra_repr(self) -> str: - return "Split: {split}".format(**self.__dict__) + return self._data_dir.is_dir() def _download(self) -> None: if self._check_exists(): return download_and_extract_archive(self._DATASET_URL, 
download_root=self.root, md5=self._DATASET_MD5) - download_and_extract_archive(self._PARTITIONS_URL, download_root=str(self._data_dir), md5=self._PARTITIONS_MD5) From 5061141aedad6930f9d2eb269924180508047784 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Jan 2022 14:25:24 +0000 Subject: [PATCH 07/12] mypy --- torchvision/datasets/gtsrb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/gtsrb.py b/torchvision/datasets/gtsrb.py index 8a8ee3bdd77..52a84ac7bca 100644 --- a/torchvision/datasets/gtsrb.py +++ b/torchvision/datasets/gtsrb.py @@ -47,7 +47,7 @@ def __init__( raise RuntimeError("Dataset not found. You can use download=True to download it") if self._split == "train": - samples = make_dataset(self._target_folder, extensions=(".ppm",)) + samples = make_dataset(str(self._target_folder), extensions=(".ppm",)) else: with open(self._base_folder / "GT-final_test.csv") as csv_file: samples = [ From 6c02cff8ff277f92cf686ddc6c3eb89b5684f825 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Jan 2022 15:08:24 +0000 Subject: [PATCH 08/12] mypy --- torchvision/datasets/gtsrb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/gtsrb.py b/torchvision/datasets/gtsrb.py index 52a84ac7bca..9a6dd934aa5 100644 --- a/torchvision/datasets/gtsrb.py +++ b/torchvision/datasets/gtsrb.py @@ -51,7 +51,7 @@ def __init__( else: with open(self._base_folder / "GT-final_test.csv") as csv_file: samples = [ - (self._target_folder / row["Filename"], int(row["ClassId"])) + (str(self._target_folder / row["Filename"]), int(row["ClassId"])) for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True) ] From 194b55de1ede1940a917d547f885bcf76b61ac5d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 20 Jan 2022 10:23:28 +0000 Subject: [PATCH 09/12] move download param for SST2 --- torchvision/datasets/rendered_sst2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/torchvision/datasets/rendered_sst2.py b/torchvision/datasets/rendered_sst2.py index 0001c6f6472..b33216855d1 100644 --- a/torchvision/datasets/rendered_sst2.py +++ b/torchvision/datasets/rendered_sst2.py @@ -21,12 +21,12 @@ class RenderedSST2(VisionDataset): Args: root (string): Root directory of the dataset. split (string, optional): The dataset split, supports ``"train"`` (default), `"val"` and ``"test"``. - download (bool, optional): If True, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. Default is False. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. Default is False. 
""" _URL = "https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz" @@ -36,9 +36,9 @@ def __init__( self, root: str, split: str = "train", - download: bool = False, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "val", "test")) From 78e52c57261f59ef4e8eeae1313581ee5fafb07d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 20 Jan 2022 10:34:46 +0000 Subject: [PATCH 10/12] Use make_dataset in SST2 --- torchvision/datasets/rendered_sst2.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/torchvision/datasets/rendered_sst2.py b/torchvision/datasets/rendered_sst2.py index b33216855d1..56cd474ff18 100644 --- a/torchvision/datasets/rendered_sst2.py +++ b/torchvision/datasets/rendered_sst2.py @@ -3,6 +3,7 @@ import PIL.Image +from .folder import make_dataset from .utils import verify_str_arg, download_and_extract_archive from .vision import VisionDataset @@ -53,18 +54,13 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. 
You can use download=True to download it") - self._labels = [] - self._image_files = [] - - for p in (self._base_folder / self._split_to_folder[self._split]).glob("**/*.png"): - self._labels.append(self.class_to_idx[p.parent.name]) - self._image_files.append(p) + self._samples = make_dataset(str(self._base_folder / self._split_to_folder[self._split]), extensions="png") def __len__(self) -> int: - return len(self._image_files) + return len(self._samples) def __getitem__(self, idx) -> Tuple[Any, Any]: - image_file, label = self._image_files[idx], self._labels[idx] + image_file, label = self._samples[idx] image = PIL.Image.open(image_file).convert("RGB") if self.transform: From dc7c1662ddb50b78129a039dc4c73e5daa1230bf Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 20 Jan 2022 10:36:46 +0000 Subject: [PATCH 11/12] Use a base URL for GTSRB --- torchvision/datasets/gtsrb.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/datasets/gtsrb.py b/torchvision/datasets/gtsrb.py index 9a6dd934aa5..f99a688586d 100644 --- a/torchvision/datasets/gtsrb.py +++ b/torchvision/datasets/gtsrb.py @@ -82,20 +82,22 @@ def download(self) -> None: if self._check_exists(): return + base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" + if self._split == "train": download_and_extract_archive( - "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip", + f"{base_url}GTSRB-Training_fixed.zip", download_root=str(self._base_folder), md5="513f3c79a4c5141765e10e952eaa2478", ) else: download_and_extract_archive( - "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip", + f"{base_url}GTSRB_Final_Test_Images.zip", download_root=str(self._base_folder), md5="c7e4e6327067d32654124b0fe9e82185", ) download_and_extract_archive( - "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip", + 
f"{base_url}GTSRB_Final_Test_GT.zip", download_root=str(self._base_folder), md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5", ) From 3c70d81164b5e747fe37bcfeff474b4bd07bb33b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 20 Jan 2022 10:52:46 +0000 Subject: [PATCH 12/12] Let's make this code more complicated than it needs to be because why not --- torchvision/datasets/rendered_sst2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/datasets/rendered_sst2.py b/torchvision/datasets/rendered_sst2.py index 56cd474ff18..02445dddb05 100644 --- a/torchvision/datasets/rendered_sst2.py +++ b/torchvision/datasets/rendered_sst2.py @@ -54,7 +54,7 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - self._samples = make_dataset(str(self._base_folder / self._split_to_folder[self._split]), extensions="png") + self._samples = make_dataset(str(self._base_folder / self._split_to_folder[self._split]), extensions=("png",)) def __len__(self) -> int: return len(self._samples)