From 02dc5f77b5bcab36e36bd5136087ec3fb6288eaf Mon Sep 17 00:00:00 2001 From: puhuk Date: Fri, 31 Dec 2021 01:03:38 +0900 Subject: [PATCH 01/10] Add Country211 dataset To addresses issue #5108. --- torchvision/datasets/__init__.py | 2 + torchvision/datasets/country211.py | 74 ++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 torchvision/datasets/country211.py diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index 1c4e24d9d0b..dfd04e7041c 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -29,6 +29,7 @@ from .vision import VisionDataset from .voc import VOCSegmentation, VOCDetection from .widerface import WIDERFace +from .country211 import Country211 __all__ = ( "LSUN", @@ -79,4 +80,5 @@ "FlyingThings3D", "HD1K", "Food101", + "Country211", ) diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py new file mode 100644 index 00000000000..41228399268 --- /dev/null +++ b/torchvision/datasets/country211.py @@ -0,0 +1,74 @@ +import json +from pathlib import Path +from typing import Any, Tuple, Callable, Optional + +import PIL.Image + +from .utils import verify_str_arg, download_and_extract_archive +from .vision import VisionDataset +from .folder import find_classes, make_dataset + +class Country211(VisionDataset): + """`The Country211 Data Set `_. + + filtered the YFCC100m dataset that have GPS coordinate corresponding to a ISO-3166 country code + and created a balanced dataset by sampling 150 train images, 50 validation images, + and 100 test images images for each country. + + + Args: + root (string): Root directory of the dataset. + split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. + transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop``. + target_transform (callable, optional): A function/transform that takes in the target and transforms it. + """ + + _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz" + + def __init__( + self, + root: str, + split: str = "train", + download: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + extensions: Tuple[str, ...] = ("jpg", "png"), + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self._split = verify_str_arg(split, "split", ("train", "valid", "test")) + self._base_folder = Path(self.root) / "country211" + + if download: + self._download() + + self._labels = [] + self._image_files = [] + + self.split_folder = self._base_folder / self._split + + self.classes, class_to_idx = find_classes(self.split_folder) + self.samples = make_dataset(self.split_folder, class_to_idx, extensions, is_valid_file=None) + + def __len__(self) -> int: + return len(self._image_files) + + def __getitem__(self, idx) -> Tuple[Any, Any]: + image_file, label = self.samples[idx][0], self.samples[idx][1] + image = PIL.Image.open(image_file).convert("RGB") + + if self.transform: + image = self.transform(image) + + if self.target_transform: + label = self.target_transform(label) + + return image, label + + def _check_exists(self) -> bool: + return all(folder.exists() and folder.is_dir() for folder in (Path(self.root), self._base_folder, self._images_folder)) + + def _download(self) -> None: + if self._check_exists(): + return + download_and_extract_archive(self._URL, download_root=self.root, md5=None) \ No newline at end of file From 83921ec9c93a5a25452e201a86759aa34a2357a3 Mon Sep 17 00:00:00 2001 From: puhuk Date: Fri, 31 Dec 2021 01:06:18 +0900 Subject: [PATCH 02/10] Add Country211 dataset To addresses issue #5108. --- torchvision/datasets/__init__.py | 2 +- torchvision/datasets/country211.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index dfd04e7041c..be42130ed18 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -4,6 +4,7 @@ from .cifar import CIFAR10, CIFAR100 from .cityscapes import Cityscapes from .coco import CocoCaptions, CocoDetection +from .country211 import Country211 from .fakedata import FakeData from .flickr import Flickr8k, Flickr30k from .folder import ImageFolder, DatasetFolder @@ -29,7 +30,6 @@ from .vision import VisionDataset from .voc import VOCSegmentation, VOCDetection from .widerface import WIDERFace -from .country211 import Country211 __all__ = ( "LSUN", diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index 41228399268..7b10cc2539e 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -4,15 +4,16 @@ import PIL.Image +from .folder import find_classes, make_dataset from .utils import verify_str_arg, download_and_extract_archive from .vision import VisionDataset -from .folder import find_classes, make_dataset + class Country211(VisionDataset): """`The Country211 Data Set `_. - filtered the YFCC100m dataset that have GPS coordinate corresponding to a ISO-3166 country code - and created a balanced dataset by sampling 150 train images, 50 validation images, + filtered the YFCC100m dataset that have GPS coordinate corresponding to a ISO-3166 country code + and created a balanced dataset by sampling 150 train images, 50 validation images, and 100 test images images for each country. @@ -44,7 +45,7 @@ def __init__( self._labels = [] self._image_files = [] - + self.split_folder = self._base_folder / self._split self.classes, class_to_idx = find_classes(self.split_folder) @@ -66,9 +67,11 @@ def __getitem__(self, idx) -> Tuple[Any, Any]: return image, label def _check_exists(self) -> bool: - return all(folder.exists() and folder.is_dir() for folder in (Path(self.root), self._base_folder, self._images_folder)) + return all( + folder.exists() and folder.is_dir() for folder in (Path(self.root), self._base_folder, self._images_folder) + ) def _download(self) -> None: if self._check_exists(): return - download_and_extract_archive(self._URL, download_root=self.root, md5=None) \ No newline at end of file + download_and_extract_archive(self._URL, download_root=self.root, md5=None) From 43869e72495364aaf217a68ba6b113ddbbde564d Mon Sep 17 00:00:00 2001 From: puhuk Date: Fri, 31 Dec 2021 13:53:56 +0900 Subject: [PATCH 03/10] Update country211.py --- torchvision/datasets/country211.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index 7b10cc2539e..86eebade447 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -1,4 +1,3 @@ -import json from pathlib import Path from typing import Any, Tuple, Callable, Optional From c383093170298b98ddd05d3ee44d1a5615d78506 Mon Sep 17 00:00:00 2001 From: puhuk Date: Fri, 31 Dec 2021 14:05:29 +0900 Subject: [PATCH 04/10] Update country211.py --- torchvision/datasets/country211.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index 86eebade447..6ee6a6e2fbb 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -42,13 +42,10 @@ def __init__( if download: self._download() - self._labels = [] - self._image_files = [] - self.split_folder = self._base_folder / self._split - self.classes, class_to_idx = find_classes(self.split_folder) - self.samples = make_dataset(self.split_folder, class_to_idx, extensions, is_valid_file=None) + self.classes, class_to_idx = find_classes(str(self.split_folder)) + self.samples = make_dataset(str(self.split_folder), class_to_idx, extensions, is_valid_file=None) def __len__(self) -> int: return len(self._image_files) From 21ee73804f458df27af6e04ffef861ecfab7f3f3 Mon Sep 17 00:00:00 2001 From: puhuk Date: Wed, 5 Jan 2022 23:39:54 +0900 Subject: [PATCH 05/10] Code review reflected Reflect code review --- ClassName.txt | 3 ++ docs/source/datasets.rst | 1 + test/test_datasets.py | 44 ++++++++++++++++++++++++++++++ torchvision/datasets/country211.py | 8 ++++-- 4 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 ClassName.txt diff --git a/ClassName.txt b/ClassName.txt new file mode 100644 index 00000000000..a06a60da80d --- /dev/null +++ b/ClassName.txt @@ -0,0 +1,3 @@ +/a/abbey +/a/airplane_cabin +/a/airport_terminal \ No newline at end of file diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 3a2872a6388..9864e41e736 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -73,6 +73,7 @@ You can also create your own datasets using the provided :ref:`base classes bool: def _download(self) -> None: if self._check_exists(): return - download_and_extract_archive(self._URL, download_root=self.root, md5=None) + download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5) From b7b3109274a73499bf48809a1dfee023a255434d Mon Sep 17 00:00:00 2001 From: puhuk Date: Thu, 6 Jan 2022 11:53:27 +0900 Subject: [PATCH 06/10] Update test_datasets.py --- test/test_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index ebfea02a98a..b6b4992fbb5 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2224,7 +2224,7 @@ def inject_fake_data(self, tmpdir: str, config): im_paths.extend( datasets_utils.create_image_folder( image_folder, - name = cls, + name=cls, file_name_fn=lambda idx: f"{cls}_{idx:05d}.jpg", num_examples=num_images_per_class, ) From 999d38614cc03f7e4c9addcfea803c0784c1580c Mon Sep 17 00:00:00 2001 From: puhuk Date: Sat, 8 Jan 2022 22:38:58 +0900 Subject: [PATCH 07/10] Update with review Update with review --- ClassName.txt | 3 -- docs/source/datasets.rst | 2 +- test/test_datasets.py | 46 +++++++++--------------------- torchvision/datasets/country211.py | 6 ++-- 4 files changed, 16 insertions(+), 41 deletions(-) delete mode 100644 ClassName.txt diff --git a/ClassName.txt b/ClassName.txt deleted file mode 100644 index a06a60da80d..00000000000 --- a/ClassName.txt +++ /dev/null @@ -1,3 +0,0 @@ -/a/abbey -/a/airplane_cabin -/a/airport_terminal \ No newline at end of file diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 9864e41e736..35e951c9a1a 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -38,6 +38,7 @@ You can also create your own datasets using the provided :ref:`base classes int: - return len(self._image_files) + return len(self.samples) def __getitem__(self, idx) -> Tuple[Any, Any]: image_file, label = self.samples[idx][0], self.samples[idx][1] @@ -67,9 +67,7 @@ def __getitem__(self, idx) -> Tuple[Any, Any]: return image, label def _check_exists(self) -> bool: - return all( - folder.exists() and folder.is_dir() for folder in (Path(self.root), self._base_folder, self._images_folder) - ) + return self._base_folder.exists() and self._base_folder.is_dir() def _download(self) -> None: if self._check_exists(): From 7144d5cd4f1b385f92c91a6534c9c6dafed7f487 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 10 Jan 2022 10:48:16 +0100 Subject: [PATCH 08/10] inherit from ImageFolder --- torchvision/datasets/country211.py | 39 ++++++++---------------------- 1 file changed, 10 insertions(+), 29 deletions(-) diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index ff5a5dcb2e0..67d67f02047 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -1,14 +1,11 @@ from pathlib import Path -from typing import Any, Tuple, Callable, Optional +from typing import Callable, Optional -import PIL.Image - -from .folder import find_classes, make_dataset +from .folder import ImageFolder from .utils import verify_str_arg, download_and_extract_archive -from .vision import VisionDataset -class Country211(VisionDataset): +class Country211(ImageFolder): """`The Country211 Data Set `_. filtered the YFCC100m dataset that have GPS coordinate corresponding to a ISO-3166 country code @@ -31,14 +28,15 @@ def __init__( self, root: str, split: str = "train", - download: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, - extensions: Tuple[str, ...] = ("jpg", "png"), + download: bool = True, ) -> None: - super().__init__(root, transform=transform, target_transform=target_transform) self._split = verify_str_arg(split, "split", ("train", "valid", "test")) - self._base_folder = Path(self.root) / "country211" + + root = Path(root).expanduser() + self.root = str(root) + self._base_folder = root / "country211" if download: self._download() @@ -46,25 +44,8 @@ def __init__( if not self._check_exists(): raise RuntimeError("Dataset not found. You can use download=True to download it") - self.split_folder = self._base_folder / self._split - - self.classes, class_to_idx = find_classes(str(self.split_folder)) - self.samples = make_dataset(str(self.split_folder), class_to_idx, extensions, is_valid_file=None) - - def __len__(self) -> int: - return len(self.samples) - - def __getitem__(self, idx) -> Tuple[Any, Any]: - image_file, label = self.samples[idx][0], self.samples[idx][1] - image = PIL.Image.open(image_file).convert("RGB") - - if self.transform: - image = self.transform(image) - - if self.target_transform: - label = self.target_transform(label) - - return image, label + super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform) + self.root = str(root) def _check_exists(self) -> bool: return self._base_folder.exists() and self._base_folder.is_dir() From dc185c90af8e5c6e9efa9b01d301a80c4414c618 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 10 Jan 2022 10:52:04 +0100 Subject: [PATCH 09/10] Update test/test_datasets.py --- test/test_datasets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index 750e2046ff3..b05431d5dab 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2465,7 +2465,6 @@ def _meta_to_split_and_classification_ann(self, meta, idx): class Country211TestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.Country211 - FEATURE_TYPES = (PIL.Image.Image, int) ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "valid", "test")) From 9eb4f11821b4345d079f67c132d650bc533453da Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 10 Jan 2022 12:13:03 +0000 Subject: [PATCH 10/10] Docstring + minor test update --- test/test_datasets.py | 30 +++++++++++++++++------------- torchvision/datasets/country211.py | 12 +++++++----- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/test/test_datasets.py b/test/test_datasets.py index b05431d5dab..b0681c507d2 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2471,19 +2471,23 @@ class Country211TestCase(datasets_utils.ImageDatasetTestCase): def inject_fake_data(self, tmpdir: str, config): split_folder = pathlib.Path(tmpdir) / "country211" / config["split"] split_folder.mkdir(parents=True, exist_ok=True) - return sum( - [ - len( - datasets_utils.create_image_folder( - split_folder, - name=cls, - file_name_fn=lambda idx: f"{idx}.jpg", - num_examples=5, - ) - ) - for cls in ("AD", "BS", "GR") - ] - ) + + num_examples = { + "train": 3, + "valid": 4, + "test": 5, + }[config["split"]] + + classes = ("AD", "BS", "GR") + for cls in classes: + datasets_utils.create_image_folder( + split_folder, + name=cls, + file_name_fn=lambda idx: f"{idx}.jpg", + num_examples=num_examples, + ) + + return num_examples * len(classes) if __name__ == "__main__": diff --git a/torchvision/datasets/country211.py b/torchvision/datasets/country211.py index 67d67f02047..20b69bc729e 100644 --- a/torchvision/datasets/country211.py +++ b/torchvision/datasets/country211.py @@ -6,12 +6,12 @@ class Country211(ImageFolder): - """`The Country211 Data Set `_. - - filtered the YFCC100m dataset that have GPS coordinate corresponding to a ISO-3166 country code - and created a balanced dataset by sampling 150 train images, 50 validation images, - and 100 test images images for each country. + """`The Country211 Data Set `_ from OpenAI. + This dataset was built by filtering the images from the YFCC100m dataset + that have GPS coordinate corresponding to a ISO-3166 country code. The + dataset is balanced by sampling 150 train images, 50 validation images, and + 100 test images images for each country. Args: root (string): Root directory of the dataset. @@ -19,6 +19,8 @@ class Country211(ImageFolder): transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``. target_transform (callable, optional): A function/transform that takes in the target and transforms it. + download (bool, optional): If True, downloads the dataset from the internet and puts it into + ``root/country211/``. If dataset is already downloaded, it is not downloaded again. """ _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"