Skip to content

Cleanups for FLAVA datasets #5164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jan 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,7 @@ ignore_missing_imports = True
[mypy-torchdata.*]

ignore_missing_imports = True

[mypy-h5py.*]

ignore_missing_imports = True
28 changes: 6 additions & 22 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2281,11 +2281,6 @@ def inject_fake_data(self, tmpdir: str, config):
class SUN397TestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.SUN397

ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
split=("train", "test"),
partition=(1, 10, None),
)

def inject_fake_data(self, tmpdir: str, config):
data_dir = pathlib.Path(tmpdir) / "SUN397"
data_dir.mkdir()
Expand All @@ -2308,18 +2303,7 @@ def inject_fake_data(self, tmpdir: str, config):
with open(data_dir / "ClassName.txt", "w") as file:
file.writelines("\n".join(f"/{cls[0]}/{cls}" for cls in sampled_classes))

if config["partition"] is not None:
num_samples = max(len(im_paths) // (2 if config["split"] == "train" else 3), 1)

with open(data_dir / f"{config['split'].title()}ing_{config['partition']:02d}.txt", "w") as file:
file.writelines(
"\n".join(
f"/{f_path.relative_to(data_dir).as_posix()}"
for f_path in random.choices(im_paths, k=num_samples)
)
)
else:
num_samples = len(im_paths)
num_samples = len(im_paths)

return num_samples

Expand Down Expand Up @@ -2397,17 +2381,17 @@ class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.GTSRB
FEATURE_TYPES = (PIL.Image.Image, int)

ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(train=(True, False))
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"))

def inject_fake_data(self, tmpdir: str, config):
root_folder = os.path.join(tmpdir, "GTSRB")
root_folder = os.path.join(tmpdir, "gtsrb")
os.makedirs(root_folder, exist_ok=True)

# Train data
train_folder = os.path.join(root_folder, "Training")
train_folder = os.path.join(root_folder, "GTSRB", "Training")
os.makedirs(train_folder, exist_ok=True)

num_examples = 3
num_examples = 3 if config["split"] == "train" else 4
classes = ("00000", "00042", "00012")
for class_idx in classes:
datasets_utils.create_image_folder(
Expand All @@ -2419,7 +2403,7 @@ def inject_fake_data(self, tmpdir: str, config):

total_number_of_examples = num_examples * len(classes)
# Test data
test_folder = os.path.join(root_folder, "Final_Test", "Images")
test_folder = os.path.join(root_folder, "GTSRB", "Final_Test", "Images")
os.makedirs(test_folder, exist_ok=True)

with open(os.path.join(root_folder, "GT-final_test.csv"), "w") as csv_file:
Expand Down
2 changes: 1 addition & 1 deletion torchvision/datasets/clevr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
download: bool = False,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
super().__init__(root, transform=transform, target_transform=target_transform)
Expand Down
2 changes: 1 addition & 1 deletion torchvision/datasets/country211.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
download: bool = False,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "valid", "test"))

Expand Down
8 changes: 4 additions & 4 deletions torchvision/datasets/dtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ class DTD(VisionDataset):
The partition only changes which split each image belongs to. Thus, regardless of the selected
partition, combining all splits will result in all images.

download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
"""

_URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
Expand All @@ -37,9 +37,9 @@ def __init__(
root: str,
split: str = "train",
partition: int = 1,
download: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
if not isinstance(partition, int) and not (1 <= partition <= 10):
Expand Down
22 changes: 12 additions & 10 deletions torchvision/datasets/eurosat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any
from typing import Callable, Optional

from .folder import ImageFolder
from .utils import download_and_extract_archive
Expand All @@ -10,23 +10,21 @@ class EuroSAT(ImageFolder):

Args:
root (string): Root directory of dataset where ``root/eurosat`` exists.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
"""

url = "https://madm.dfki.de/files/sentinel/EuroSAT.zip"
md5 = "c8fa014336c82ac7804f0398fcb19387"

def __init__(
self,
root: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
**kwargs: Any,
) -> None:
self.root = os.path.expanduser(root)
self._base_folder = os.path.join(self.root, "eurosat")
Expand All @@ -38,7 +36,7 @@ def __init__(
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")

super().__init__(self._data_folder, **kwargs)
super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
self.root = os.path.expanduser(root)

def __len__(self) -> int:
Expand All @@ -53,4 +51,8 @@ def download(self) -> None:
return

os.makedirs(self._base_folder, exist_ok=True)
download_and_extract_archive(self.url, download_root=self._base_folder, md5=self.md5)
download_and_extract_archive(
"https://madm.dfki.de/files/sentinel/EuroSAT.zip",
download_root=self._base_folder,
md5="c8fa014336c82ac7804f0398fcb19387",
)
8 changes: 4 additions & 4 deletions torchvision/datasets/fgvc_aircraft.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ class FGVCAircraft(VisionDataset):
root (string): Root directory of the FGVC Aircraft dataset.
split (string, optional): The dataset split, supports ``train``, ``val``,
``trainval`` and ``test``.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
annotation_level (str, optional): The annotation level, supports ``variant``,
``family`` and ``manufacturer``.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""

_URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
Expand All @@ -43,10 +43,10 @@ def __init__(
self,
root: str,
split: str = "trainval",
download: bool = False,
annotation_level: str = "variant",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
Expand Down
8 changes: 4 additions & 4 deletions torchvision/datasets/flowers102.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ class Flowers102(VisionDataset):
Args:
root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image and returns a
transformed version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""

_download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
Expand All @@ -44,9 +44,9 @@ def __init__(
self,
root: str,
split: str = "train",
download: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "val", "test"))
Expand Down
5 changes: 4 additions & 1 deletion torchvision/datasets/food101.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class Food101(VisionDataset):
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again. Default is False.
"""

_URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
Expand All @@ -33,9 +36,9 @@ def __init__(
self,
root: str,
split: str = "train",
download: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self._split = verify_str_arg(split, "split", ("train", "test"))
Expand Down
63 changes: 30 additions & 33 deletions torchvision/datasets/gtsrb.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import csv
import os
import pathlib
from typing import Any, Callable, Optional, Tuple

import PIL

from .folder import make_dataset
from .utils import download_and_extract_archive
from .utils import download_and_extract_archive, verify_str_arg
from .vision import VisionDataset


Expand All @@ -14,8 +14,7 @@ class GTSRB(VisionDataset):

Args:
root (string): Root directory of the dataset.
train (bool, optional): If True, creates dataset from training set, otherwise
creates from test set.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
Expand All @@ -24,49 +23,35 @@ class GTSRB(VisionDataset):
downloaded again.
"""

# Ground Truth for the test set
_gt_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_GT.zip"
_gt_csv = "GT-final_test.csv"
_gt_md5 = "fe31e9c9270bbcd7b84b7f21a9d9d9e5"

# URLs for the test and train set
_urls = (
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip",
"https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip",
)

_md5s = ("c7e4e6327067d32654124b0fe9e82185", "513f3c79a4c5141765e10e952eaa2478")

def __init__(
self,
root: str,
train: bool = True,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:

super().__init__(root, transform=transform, target_transform=target_transform)

self.root = os.path.expanduser(root)

self.train = train

self._base_folder = os.path.join(self.root, type(self).__name__)
self._target_folder = os.path.join(self._base_folder, "Training" if self.train else "Final_Test/Images")
self._split = verify_str_arg(split, "split", ("train", "test"))
self._base_folder = pathlib.Path(root) / "gtsrb"
self._target_folder = (
self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
)

if download:
self.download()

if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")

if train:
samples = make_dataset(self._target_folder, extensions=(".ppm",))
if self._split == "train":
samples = make_dataset(str(self._target_folder), extensions=(".ppm",))
else:
with open(os.path.join(self._base_folder, self._gt_csv)) as csv_file:
with open(self._base_folder / "GT-final_test.csv") as csv_file:
samples = [
(os.path.join(self._target_folder, row["Filename"]), int(row["ClassId"]))
(str(self._target_folder / row["Filename"]), int(row["ClassId"]))
for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
]

Expand All @@ -91,16 +76,28 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
return sample, target

def _check_exists(self) -> bool:
return os.path.exists(self._target_folder) and os.path.isdir(self._target_folder)
return self._target_folder.is_dir()

def download(self) -> None:
if self._check_exists():
return

download_and_extract_archive(self._urls[self.train], download_root=self.root, md5=self._md5s[self.train])
base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"

if not self.train:
# Download Ground Truth for the test set
if self._split == "train":
download_and_extract_archive(
f"{base_url}GTSRB-Training_fixed.zip",
download_root=str(self._base_folder),
md5="513f3c79a4c5141765e10e952eaa2478",
)
else:
download_and_extract_archive(
f"{base_url}GTSRB_Final_Test_Images.zip",
download_root=str(self._base_folder),
md5="c7e4e6327067d32654124b0fe9e82185",
)
download_and_extract_archive(
self._gt_url, download_root=self.root, extract_root=self._base_folder, md5=self._gt_md5
f"{base_url}GTSRB_Final_Test_GT.zip",
download_root=str(self._base_folder),
md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
)
2 changes: 1 addition & 1 deletion torchvision/datasets/oxford_iiit_pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
download: bool = False,
):
self._split = verify_str_arg(split, "split", ("trainval", "test"))
if isinstance(target_types, str):
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/pcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def __init__(
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
download: bool = False,
):
try:
import h5py # type: ignore[import]
import h5py

self.h5py = h5py
except ImportError:
Expand Down
Loading