Add support for PCAM dataset #5203


Merged
merged 6 commits on Jan 17, 2022
1 change: 1 addition & 0 deletions .circleci/unittest/linux/scripts/environment.yml
@@ -9,6 +9,7 @@ dependencies:
- libpng
- jpeg
- ca-certificates
- h5py
- pip:
- future
- pillow >=5.3.0, !=8.3.*
1 change: 1 addition & 0 deletions .circleci/unittest/windows/scripts/environment.yml
@@ -9,6 +9,7 @@ dependencies:
- libpng
- jpeg
- ca-certificates
- h5py
- pip:
- future
- pillow >=5.3.0, !=8.3.*
1 change: 1 addition & 0 deletions docs/source/datasets.rst
@@ -66,6 +66,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
MNIST
Omniglot
OxfordIIITPet
PCAM
PhotoTour
Places365
QMNIST
1 change: 1 addition & 0 deletions test/datasets_utils.py
@@ -61,6 +61,7 @@ class LazyImporter:
"requests",
"scipy.io",
"scipy.sparse",
"h5py",
)

def __init__(self):
23 changes: 23 additions & 0 deletions test/test_datasets.py
@@ -2577,5 +2577,28 @@ def inject_fake_data(self, tmpdir: str, config):
return num_images_per_split[config["split"]]


class PCAMTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.PCAM

ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
REQUIRED_PACKAGES = ("h5py",)

def inject_fake_data(self, tmpdir: str, config):
base_folder = pathlib.Path(tmpdir) / "pcam"
base_folder.mkdir()

num_images = {"train": 2, "test": 3, "val": 4}[config["split"]]

images_file = datasets.PCAM._FILES[config["split"]]["images"][0]
with datasets_utils.lazy_importer.h5py.File(str(base_folder / images_file), "w") as f:
f["x"] = np.random.randint(0, 256, size=(num_images, 10, 10, 3), dtype=np.uint8)

targets_file = datasets.PCAM._FILES[config["split"]]["targets"][0]
with datasets_utils.lazy_importer.h5py.File(str(base_folder / targets_file), "w") as f:
f["y"] = np.random.randint(0, 2, size=(num_images, 1, 1, 1), dtype=np.uint8)

return num_images


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions torchvision/datasets/__init__.py
@@ -25,6 +25,7 @@
from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST
from .omniglot import Omniglot
from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM
from .phototour import PhotoTour
from .places365 import Places365
from .sbd import SBDataset
4 changes: 2 additions & 2 deletions torchvision/datasets/oxford_iiit_pet.py
@@ -27,8 +27,8 @@ class OxfordIIITPet(VisionDataset):
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
- download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/dtd``. If
-     dataset is already downloaded, it is not downloaded again.
+ download (bool, optional): If True, downloads the dataset from the internet and puts it into
+     ``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again.
Comment on lines +30 to +31
Collaborator: Thanks!
"""

_RESOURCES = (
130 changes: 130 additions & 0 deletions torchvision/datasets/pcam.py
@@ -0,0 +1,130 @@
import pathlib
from typing import Any, Callable, Optional, Tuple

from PIL import Image

from .utils import download_file_from_google_drive, _decompress, verify_str_arg
from .vision import VisionDataset


class PCAM(VisionDataset):
"""`PCAM Dataset <https://github.com/basveeling/pcam>`_.

The PatchCamelyon dataset is a binary classification dataset with 327,680
color images (96px x 96px), extracted from histopathologic scans of lymph node
sections. Each image is annotated with a binary label indicating presence of
metastatic tissue.

This dataset requires the ``h5py`` package which you can install with ``pip install h5py``.

Args:
root (string): Root directory of the dataset.
split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop``.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
dataset is already downloaded, it is not downloaded again.
"""

_FILES = {
"train": {
"images": (
"camelyonpatch_level_2_split_train_x.h5", # Data file name
"1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2", # Google Drive ID
"1571f514728f59376b705fc836ff4b63", # md5 hash
),
Comment on lines +30 to +36
Member Author: I'm not ecstatic about this big dict, but I needed everything in the same place to support per-split download logic (i.e. only download the test data if we need neither train nor val).
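For illustration, a minimal sketch (not part of the PR) of the per-split lookup this layout enables; the "test" split is just an example:

from torchvision.datasets import PCAM

# Iterates only the two test-split entries; train/val files are never touched.
for file_name, file_id, md5 in PCAM._FILES["test"].values():
    print(file_name)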

"targets": (
"camelyonpatch_level_2_split_train_y.h5",
"1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
"35c2d7259d906cfc8143347bb8e05be7",
),
},
"test": {
"images": (
"camelyonpatch_level_2_split_test_x.h5",
"1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
"d5b63470df7cfa627aeec8b9dc0c066e",
),
"targets": (
"camelyonpatch_level_2_split_test_y.h5",
"17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
"2b85f58b927af9964a4c15b8f7e8f179",
),
},
"val": {
"images": (
"camelyonpatch_level_2_split_valid_x.h5",
"1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
"d8c2d60d490dbd479f8199bdfa0cf6ec",
),
"targets": (
"camelyonpatch_level_2_split_valid_y.h5",
"1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
"60a7035772fbdb7f34eb86d4420cf66a",
),
},
}

def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = True,
):
try:
import h5py # type: ignore[import]

self.h5py = h5py
except ImportError:
raise RuntimeError(
"h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
)

self._split = verify_str_arg(split, "split", ("train", "test", "val"))

super().__init__(root, transform=transform, target_transform=target_transform)
self._base_folder = pathlib.Path(self.root) / "pcam"

if download:
self._download()

if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")

def __len__(self) -> int:
images_file = self._FILES[self._split]["images"][0]
with self.h5py.File(self._base_folder / images_file) as images_data:
Member Author:

Note for here and below: opening a File does not load its data into memory, so the operation is very cheap and fast.

Similarly, accessing a single row in the file below will not load the entire file, just the specific section of it.

I guess we could open the files and keep the handles in __init__, but I'm not sure it would be any faster, and we might not be able to ever close the handles properly.
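A minimal sketch of this access pattern (the file name and the printed shape are illustrative, assuming a downloaded train split):

import h5py

with h5py.File("camelyonpatch_level_2_split_train_x.h5", "r") as f:
    images = f["x"]       # h5py.Dataset handle; no pixel data is read yet
    print(images.shape)   # metadata lookup only, e.g. (262144, 96, 96, 3)
    patch = images[0]     # reads just this one 96x96x3 slice from disk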

return images_data["x"].shape[0]

def __getitem__(self, idx: int) -> Tuple[Any, Any]:
images_file = self._FILES[self._split]["images"][0]
with self.h5py.File(self._base_folder / images_file) as images_data:
image = Image.fromarray(images_data["x"][idx]).convert("RGB")

targets_file = self._FILES[self._split]["targets"][0]
with self.h5py.File(self._base_folder / targets_file) as targets_data:
target = int(targets_data["y"][idx, 0, 0, 0]) # shape is [num_images, 1, 1, 1]

if self.transform:
image = self.transform(image)
if self.target_transform:
target = self.target_transform(target)

return image, target

def _check_exists(self) -> bool:
images_file = self._FILES[self._split]["images"][0]
targets_file = self._FILES[self._split]["targets"][0]
return all(self._base_folder.joinpath(h5_file).exists() for h5_file in (images_file, targets_file))

def _download(self) -> None:
if self._check_exists():
return

for file_name, file_id, md5 in self._FILES[self._split].values():
archive_name = file_name + ".gz"
download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5)
_decompress(str(self._base_folder / archive_name))
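For reference, a minimal usage sketch of the new dataset (the "data" root directory is illustrative; the first call downloads and decompresses the split's HDF5 files):

from torchvision import datasets

# Requires h5py (pip install h5py); files are stored under data/pcam/.
train_set = datasets.PCAM(root="data", split="train", download=True)
image, target = train_set[0]   # PIL RGB image and an int label in {0, 1}
print(len(train_set), image.size, target)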