diff --git a/tests/ignite/__init__.py b/tests/ignite/__init__.py
index 8f84e2e74b99..10b666832d0e 100644
--- a/tests/ignite/__init__.py
+++ b/tests/ignite/__init__.py
@@ -1,11 +1,12 @@
 import torch
 
-
-def cpu_and_maybe_cuda():
-    return ("cpu",) + (("cuda",) if torch.cuda.is_available() else ())
+from ignite.distributed.comp_models.base import _torch_version_gt_112
 
 
 def is_mps_available_and_functional():
+    if not _torch_version_gt_112:
+        return False
+
     if not torch.backends.mps.is_available():
         return False
     try:
diff --git a/tests/ignite/conftest.py b/tests/ignite/conftest.py
index d5546a75bae5..0add0340bd78 100644
--- a/tests/ignite/conftest.py
+++ b/tests/ignite/conftest.py
@@ -13,6 +13,7 @@
 import torch.distributed as dist
 
 import ignite.distributed as idist
+from tests.ignite import is_mps_available_and_functional
 
 
 def pytest_addoption(parser):
@@ -72,6 +73,9 @@ def term_handler():
     params=[
         "cpu",
         pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no CUDA support")),
+        pytest.param(
+            "mps", marks=pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS support")
+        ),
     ]
 )
 def available_device(request):
diff --git a/tests/ignite/distributed/test_auto.py b/tests/ignite/distributed/test_auto.py
index 2ecc3404c907..6893369aad6e 100644
--- a/tests/ignite/distributed/test_auto.py
+++ b/tests/ignite/distributed/test_auto.py
@@ -12,7 +12,6 @@
 
 import ignite.distributed as idist
 from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
-from ignite.distributed.comp_models.base import _torch_version_gt_112
 from tests.ignite import is_mps_available_and_functional
 
 
@@ -181,10 +180,7 @@ def _test_auto_model_optimizer(ws, device):
     assert optimizer.backward_passes_per_step == backward_passes_per_step
 
 
-@pytest.mark.skipif(
-    (not _torch_version_gt_112) or (torch.backends.mps.is_available() and not is_mps_available_and_functional()),
-    reason="Skip if MPS not functional",
-)
+@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if MPS not functional")
 def test_auto_methods_no_dist():
     _test_auto_dataloader(1, 1, batch_size=1)
     _test_auto_dataloader(1, 1, batch_size=10, num_workers=2)
diff --git a/tests/ignite/distributed/test_launcher.py b/tests/ignite/distributed/test_launcher.py
index eac7ffe2e06c..6dcbfb198d82 100644
--- a/tests/ignite/distributed/test_launcher.py
+++ b/tests/ignite/distributed/test_launcher.py
@@ -8,7 +8,6 @@
 from packaging.version import Version
 
 import ignite.distributed as idist
-from ignite.distributed.comp_models.base import _torch_version_gt_112
 from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support
 from tests.ignite import is_mps_available_and_functional
 
@@ -56,10 +55,7 @@ def execute(cmd, env=None):
     return str(process.stdout.read()) + str(process.stderr.read())
 
 
-@pytest.mark.skipif(
-    (not _torch_version_gt_112) or (torch.backends.mps.is_available() and not is_mps_available_and_functional()),
-    reason="Skip if MPS not functional",
-)
+@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if MPS not functional")
 def test_check_idist_parallel_no_dist(exec_filepath):
     cmd = [sys.executable, "-u", exec_filepath]
     out = execute(cmd)
diff --git a/tests/ignite/engine/test_create_supervised.py b/tests/ignite/engine/test_create_supervised.py
index 4f07c95929e0..ba42baddddae 100644
--- a/tests/ignite/engine/test_create_supervised.py
+++ b/tests/ignite/engine/test_create_supervised.py
@@ -12,7 +12,6 @@
 from torch.optim import SGD
 
 import ignite.distributed as idist
-from ignite.distributed.comp_models.base import _torch_version_gt_112
 from ignite.engine import (
     _check_arg,
     create_supervised_evaluator,
@@ -488,7 +487,7 @@ def test_create_supervised_trainer_on_cuda():
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
 
 
-@pytest.mark.skipif(not (_torch_version_gt_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
+@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS")
 def test_create_supervised_trainer_on_mps():
     model_device = trainer_device = "mps"
     _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
@@ -669,14 +668,14 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
     _test_mocked_supervised_evaluator(evaluator_device="cuda")
 
 
-@pytest.mark.skipif(not (_torch_version_gt_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
+@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS")
 def test_create_supervised_evaluator_on_mps():
     model_device = evaluator_device = "mps"
     _test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
     _test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
 
 
-@pytest.mark.skipif(not (_torch_version_gt_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
+@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS")
 def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
     _test_create_supervised_evaluator(evaluator_device="mps")
     _test_mocked_supervised_evaluator(evaluator_device="mps")
diff --git a/tests/ignite/metrics/conftest.py b/tests/ignite/metrics/conftest.py
new file mode 100644
index 000000000000..a86dd4dce92f
--- /dev/null
+++ b/tests/ignite/metrics/conftest.py
@@ -0,0 +1,83 @@
+import pytest
+import torch
+
+
+@pytest.fixture(params=range(14))
+def test_data_binary(request):
+    return [
+        # Binary accuracy on input of shape (N, 1) or (N, )
+        (torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)), 1),
+        (torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1),
+        # updated batches
+        (torch.randint(0, 2, size=(50,)), torch.randint(0, 2, size=(50,)), 16),
+        (torch.randint(0, 2, size=(50, 1)), torch.randint(0, 2, size=(50, 1)), 16),
+        # Binary accuracy on input of shape (N, L)
+        (torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
+        (torch.randint(0, 2, size=(10, 1, 5)), torch.randint(0, 2, size=(10, 1, 5)), 1),
+        # updated batches
+        (torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16),
+        (torch.randint(0, 2, size=(50, 1, 5)), torch.randint(0, 2, size=(50, 1, 5)), 16),
+        # Binary accuracy on input of shape (N, H, W)
+        (torch.randint(0, 2, size=(10, 12, 10)), torch.randint(0, 2, size=(10, 12, 10)), 1),
+        (torch.randint(0, 2, size=(10, 1, 12, 10)), torch.randint(0, 2, size=(10, 1, 12, 10)), 1),
+        # updated batches
+        (torch.randint(0, 2, size=(50, 12, 10)), torch.randint(0, 2, size=(50, 12, 10)), 16),
+        (torch.randint(0, 2, size=(50, 1, 12, 10)), torch.randint(0, 2, size=(50, 1, 12, 10)), 16),
+        # Corner case with all zeros predictions
+        (torch.zeros(size=(10,), dtype=torch.long), torch.randint(0, 2, size=(10,)), 1),
+        (torch.zeros(size=(10, 1), dtype=torch.long), torch.randint(0, 2, size=(10, 1)), 1),
+    ][request.param]
+
+
+@pytest.fixture(params=range(14))
+def test_data_multiclass(request):
+    return [
+        # Multiclass input data of shape (N, ) and (N, C)
+        (torch.rand(10, 6), torch.randint(0, 6, size=(10,)), 1),
+        (torch.rand(10, 4), torch.randint(0, 4, size=(10,)), 1),
+        # updated batches
+        (torch.rand(50, 6), torch.randint(0, 6, size=(50,)), 16),
+        (torch.rand(50, 4), torch.randint(0, 4, size=(50,)), 16),
+        # Multiclass input data of shape (N, L) and (N, C, L)
+        (torch.rand(10, 5, 8), torch.randint(0, 5, size=(10, 8)), 1),
+        (torch.rand(10, 8, 12), torch.randint(0, 8, size=(10, 12)), 1),
+        # updated batches
+        (torch.rand(50, 5, 8), torch.randint(0, 5, size=(50, 8)), 16),
+        (torch.rand(50, 8, 12), torch.randint(0, 8, size=(50, 12)), 16),
+        # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
+        (torch.rand(10, 5, 18, 16), torch.randint(0, 5, size=(10, 18, 16)), 1),
+        (torch.rand(10, 7, 20, 12), torch.randint(0, 7, size=(10, 20, 12)), 1),
+        # updated batches
+        (torch.rand(50, 5, 18, 16), torch.randint(0, 5, size=(50, 18, 16)), 16),
+        (torch.rand(50, 7, 20, 12), torch.randint(0, 7, size=(50, 20, 12)), 16),
+        # Corner case with all zeros predictions
+        (torch.zeros(size=(10, 6)), torch.randint(0, 6, size=(10,)), 1),
+        (torch.zeros(size=(10, 4)), torch.randint(0, 4, size=(10,)), 1),
+    ][request.param]
+
+
+@pytest.fixture(params=range(14))
+def test_data_multilabel(request):
+    return [
+        # Multilabel input data of shape (N, C)
+        (torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
+        (torch.randint(0, 2, size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1),
+        # updated batches
+        (torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16),
+        (torch.randint(0, 2, size=(50, 4)), torch.randint(0, 2, size=(50, 4)), 16),
+        # Multilabel input data of shape (N, C, L)
+        (torch.randint(0, 2, size=(10, 5, 10)), torch.randint(0, 2, size=(10, 5, 10)), 1),
+        (torch.randint(0, 2, size=(10, 4, 10)), torch.randint(0, 2, size=(10, 4, 10)), 1),
+        # updated batches
+        (torch.randint(0, 2, size=(50, 5, 10)), torch.randint(0, 2, size=(50, 5, 10)), 16),
+        (torch.randint(0, 2, size=(50, 4, 10)), torch.randint(0, 2, size=(50, 4, 10)), 16),
+        # Multilabel input data of shape (N, C, H, W)
+        (torch.randint(0, 2, size=(10, 5, 18, 16)), torch.randint(0, 2, size=(10, 5, 18, 16)), 1),
+        (torch.randint(0, 2, size=(10, 4, 20, 23)), torch.randint(0, 2, size=(10, 4, 20, 23)), 1),
+        # updated batches
+        (torch.randint(0, 2, size=(50, 5, 18, 16)), torch.randint(0, 2, size=(50, 5, 18, 16)), 16),
+        (torch.randint(0, 2, size=(50, 4, 20, 23)), torch.randint(0, 2, size=(50, 4, 20, 23)), 16),
+        # Corner case with all zeros predictions
+        (torch.zeros(size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
+        (torch.zeros(size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1),
+    ][request.param]
diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py
index 35631b2b47e7..2f7a3e890e89 100644
--- a/tests/ignite/metrics/test_accuracy.py
+++ b/tests/ignite/metrics/test_accuracy.py
@@ -1,10 +1,8 @@
-import os
 from typing import Callable, Union
 from unittest.mock import MagicMock
 
 import pytest
 import torch
-from packaging.version import Version
 from sklearn.metrics import accuracy_score
 
 import ignite.distributed as idist
@@ -65,33 +63,9 @@ def test_binary_wrong_inputs():
         acc.update((torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10, 5, 6)).long()))
 
 
-@pytest.fixture(params=range(12))
-def test_data_binary(request):
-    return [
-        # Binary accuracy on input of shape (N, 1) or (N, )
-        (torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long(), 1),
-        (torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 1)).long(), 1),
-        # updated batches
-        (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
-        (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
-        # Binary accuracy on input of shape (N, L)
-        (torch.randint(0, 2, size=(10, 5)).long(), torch.randint(0, 2, size=(10, 5)).long(), 1),
-        (torch.randint(0, 2, size=(10, 8)).long(), torch.randint(0, 2, size=(10, 8)).long(), 1),
-        # updated batches
-        (torch.randint(0, 2, size=(50, 5)).long(), torch.randint(0, 2, size=(50, 5)).long(), 16),
-        (torch.randint(0, 2, size=(50, 8)).long(), torch.randint(0, 2, size=(50, 8)).long(), 16),
-        # Binary accuracy on input of shape (N, H, W, ...)
-        (torch.randint(0, 2, size=(4, 1, 12, 10)).long(), torch.randint(0, 2, size=(4, 1, 12, 10)).long(), 1),
-        (torch.randint(0, 2, size=(15, 1, 20, 10)).long(), torch.randint(0, 2, size=(15, 1, 20, 10)).long(), 1),
-        # updated batches
-        (torch.randint(0, 2, size=(50, 1, 12, 10)).long(), torch.randint(0, 2, size=(50, 1, 12, 10)).long(), 16),
-        (torch.randint(0, 2, size=(50, 1, 20, 10)).long(), torch.randint(0, 2, size=(50, 1, 20, 10)).long(), 16),
-    ][request.param]
-
-
-@pytest.mark.parametrize("n_times", range(5))
-def test_binary_input(n_times, test_data_binary):
-    acc = Accuracy()
+@pytest.mark.parametrize("n_times", range(3))
+def test_binary_input(n_times, available_device, test_data_binary):
+    acc = Accuracy(device=available_device)
 
     y_pred, y, batch_size = test_data_binary
     acc.reset()
@@ -127,30 +101,9 @@ def test_multiclass_wrong_inputs():
         acc.update((torch.rand(10), torch.randint(0, 5, size=(10, 5, 6)).long()))
 
 
-@pytest.fixture(params=range(11))
-def test_data_multiclass(request):
-    return [
-        # Multiclass input data of shape (N, ) and (N, C)
-        (torch.rand(10, 4), torch.randint(0, 4, size=(10,)).long(), 1),
-        (torch.rand(10, 10, 1), torch.randint(0, 18, size=(10, 1)).long(), 1),
-        (torch.rand(10, 18), torch.randint(0, 18, size=(10,)).long(), 1),
-        (torch.rand(4, 10), torch.randint(0, 10, size=(4,)).long(), 1),
-        # 2-classes
-        (torch.rand(4, 2), torch.randint(0, 2, size=(4,)).long(), 1),
-        (torch.rand(100, 5), torch.randint(0, 5, size=(100,)).long(), 16),
-        # Multiclass input data of shape (N, L) and (N, C, L)
-        (torch.rand(10, 4, 5), torch.randint(0, 4, size=(10, 5)).long(), 1),
-        (torch.rand(4, 10, 5), torch.randint(0, 10, size=(4, 5)).long(), 1),
-        (torch.rand(100, 9, 7), torch.randint(0, 9, size=(100, 7)).long(), 16),
-        # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
-        (torch.rand(4, 5, 12, 10), torch.randint(0, 5, size=(4, 12, 10)).long(), 1),
-        (torch.rand(100, 3, 8, 8), torch.randint(0, 3, size=(100, 8, 8)).long(), 16),
-    ][request.param]
-
-
-@pytest.mark.parametrize("n_times", range(5))
-def test_multiclass_input(n_times, test_data_multiclass):
-    acc = Accuracy()
+@pytest.mark.parametrize("n_times", range(3))
+def test_multiclass_input(n_times, available_device, test_data_multiclass):
+    acc = Accuracy(device=available_device)
 
     y_pred, y, batch_size = test_data_multiclass
     acc.reset()
@@ -199,33 +152,9 @@ def test_multilabel_wrong_inputs():
         acc.update((torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)).long()))
 
 
-@pytest.fixture(params=range(12))
-def test_data_multilabel(request):
-    return [
-        # Multilabel input data of shape (N, C) and (N, C)
-        (torch.randint(0, 2, size=(10, 4)).long(), torch.randint(0, 2, size=(10, 4)).long(), 1),
-        (torch.randint(0, 2, size=(10, 7)).long(), torch.randint(0, 2, size=(10, 7)).long(), 1),
-        # updated batches
-        (torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16),
-        (torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16),
-        # Multilabel input data of shape (N, H, W)
-        (torch.randint(0, 2, size=(10, 5, 10)).long(), torch.randint(0, 2, size=(10, 5, 10)).long(), 1),
-        (torch.randint(0, 2, size=(10, 4, 10)).long(), torch.randint(0, 2, size=(10, 4, 10)).long(), 1),
-        # updated batches
-        (torch.randint(0, 2, size=(50, 5, 10)).long(), torch.randint(0, 2, size=(50, 5, 10)).long(), 16),
-        (torch.randint(0, 2, size=(50, 4, 10)).long(), torch.randint(0, 2, size=(50, 4, 10)).long(), 16),
-        # Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
-        (torch.randint(0, 2, size=(4, 5, 12, 10)).long(), torch.randint(0, 2, size=(4, 5, 12, 10)).long(), 1),
-        (torch.randint(0, 2, size=(4, 10, 12, 8)).long(), torch.randint(0, 2, size=(4, 10, 12, 8)).long(), 1),
-        # updated batches
-        (torch.randint(0, 2, size=(50, 5, 12, 10)).long(), torch.randint(0, 2, size=(50, 5, 12, 10)).long(), 16),
-        (torch.randint(0, 2, size=(50, 10, 12, 8)).long(), torch.randint(0, 2, size=(50, 10, 12, 8)).long(), 16),
-    ][request.param]
-
-
-@pytest.mark.parametrize("n_times", range(5))
-def test_multilabel_input(n_times, test_data_multilabel):
-    acc = Accuracy(is_multilabel=True)
+@pytest.mark.parametrize("n_times", range(3))
+def test_multilabel_input(n_times, available_device, test_data_multilabel):
+    acc = Accuracy(is_multilabel=True, device=available_device)
 
     y_pred, y, batch_size = test_data_multilabel
     if batch_size > 1:
@@ -260,386 +189,281 @@ def test_incorrect_type():
         acc.update((y_pred, y))
 
 
-def _test_distrib_multilabel_input_NHW(device):
-    # Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
-
-    rank = idist.get_rank()
-
-    def _test(metric_device):
-        metric_device = torch.device(metric_device)
-        acc = Accuracy(is_multilabel=True, device=metric_device)
-
+@pytest.mark.usefixtures("distributed")
+class TestDistributed:
+    def test_multilabel_input_NHW(self):
+        # Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
+        rank = idist.get_rank()
         torch.manual_seed(10 + rank)
-        y_pred = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long()
-        y = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long()
-        acc.update((y_pred, y))
-
-        assert (
-            acc._num_correct.device == metric_device
-        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
 
-        n = acc._num_examples
-        assert n == y.numel() / y.size(dim=1)
-
-        # gather y_pred, y
-        y_pred = idist.all_gather(y_pred)
-        y = idist.all_gather(y)
-
-        np_y_pred = to_numpy_multilabel(y_pred.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
-        np_y = to_numpy_multilabel(y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
-        assert acc._type == "multilabel"
-        res = acc.compute()
-        assert n == acc._num_examples
-        assert isinstance(res, float)
-        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
-
-        acc.reset()
+        metric_devices = [torch.device("cpu")]
+        device = idist.device()
+        if device.type != "xla":
+            metric_devices.append(device)
+
+        for metric_device in metric_devices:
+            acc = Accuracy(is_multilabel=True, device=metric_device)
+
+            y_pred = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long()
+            y = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long()
+            acc.update((y_pred, y))
+
+            assert (
+                acc._num_correct.device == metric_device
+            ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
+
+            n = acc._num_examples
+            assert n == y.numel() / y.size(dim=1)
+
+            # gather y_pred, y
+            y_pred = idist.all_gather(y_pred)
+            y = idist.all_gather(y)
+
+            np_y_pred = to_numpy_multilabel(y_pred.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
+            np_y = to_numpy_multilabel(y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
+            assert acc._type == "multilabel"
+            res = acc.compute()
+            assert n == acc._num_examples
+            assert isinstance(res, float)
+            assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
+
+            acc.reset()
+            torch.manual_seed(10 + rank)
+            y_pred = torch.randint(0, 2, size=(4, 7, 10, 8), device=device).long()
+            y = torch.randint(0, 2, size=(4, 7, 10, 8), device=device).long()
+            acc.update((y_pred, y))
+
+            assert (
+                acc._num_correct.device == metric_device
+            ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
+
+            n = acc._num_examples
+            assert n == y.numel() / y.size(dim=1)
+
+            # gather y_pred, y
+            y_pred = idist.all_gather(y_pred)
+            y = idist.all_gather(y)
+
+            np_y_pred = to_numpy_multilabel(y_pred.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
+            np_y = to_numpy_multilabel(y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
+
+            assert acc._type == "multilabel"
+            res = acc.compute()
+            assert n == acc._num_examples
+            assert isinstance(res, float)
+            assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
+            # check that result is not changed
+            res = acc.compute()
+            assert n == acc._num_examples
+            assert isinstance(res, float)
+            assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
+
+            # Batched Updates
+            acc.reset()
+            torch.manual_seed(10 + rank)
+            y_pred = torch.randint(0, 2, size=(80, 5, 8, 10), device=device).long()
+            y = torch.randint(0, 2, size=(80, 5, 8, 10), device=device).long()
+
+            batch_size = 16
+            n_iters = y.shape[0] // batch_size + 1
+
+            for i in range(n_iters):
+                idx = i * batch_size
+                acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+
+            assert (
+                acc._num_correct.device == metric_device
+            ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
+
+            n = acc._num_examples
+            assert n == y.numel() / y.size(dim=1)
+
+            # gather y_pred, y
+            y_pred = idist.all_gather(y_pred)
+            y = idist.all_gather(y)
+
+            np_y_pred = to_numpy_multilabel(y_pred.cpu())  # (N, C, L, ...) -> (N * L * ..., C)
+            np_y = to_numpy_multilabel(y.cpu())  # (N, C, L, ...) -> (N * L ..., C)
+
+            assert acc._type == "multilabel"
+            res = acc.compute()
+            assert n == acc._num_examples
+            assert isinstance(res, float)
+            assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
+
+    @pytest.mark.parametrize("n_epochs", [1, 2])
+    def test_integration_multiclass(self, n_epochs):
+        rank = idist.get_rank()
         torch.manual_seed(10 + rank)
-        y_pred = torch.randint(0, 2, size=(4, 7, 10, 8), device=device).long()
-        y = torch.randint(0, 2, size=(4, 7, 10, 8), device=device).long()
-        acc.update((y_pred, y))
 
-        assert (
-            acc._num_correct.device == metric_device
-        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
+        n_iters = 80
+        batch_size = 16
+        n_classes = 10
 
-        n = acc._num_examples
-        assert n == y.numel() / y.size(dim=1)
+        metric_devices = [torch.device("cpu")]
+        device = idist.device()
+        if device.type != "xla":
+            metric_devices.append(device)
 
-        # gather y_pred, y
-        y_pred = idist.all_gather(y_pred)
-        y = idist.all_gather(y)
+        for metric_device in metric_devices:
+            y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
+            y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)
 
-        np_y_pred = to_numpy_multilabel(y_pred.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
-        np_y = to_numpy_multilabel(y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
+            def update(engine, i):
+                return (
+                    y_preds[i * batch_size : (i + 1) * batch_size, :],
+                    y_true[i * batch_size : (i + 1) * batch_size],
+                )
 
-        assert acc._type == "multilabel"
-        res = acc.compute()
-        assert n == acc._num_examples
-        assert isinstance(res, float)
-        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
+            engine = Engine(update)
 
-        # check that result is not changed
-        res = acc.compute()
-        assert n == acc._num_examples
-        assert isinstance(res, float)
-        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
+            acc = Accuracy(device=metric_device)
+            acc.attach(engine, "acc")
 
-        # Batched Updates
-        acc.reset()
-        torch.manual_seed(10 + rank)
-        y_pred = torch.randint(0, 2, size=(80, 5, 8, 10), device=device).long()
-        y = torch.randint(0, 2, size=(80, 5, 8, 10), device=device).long()
+            data = list(range(n_iters))
+            engine.run(data=data, max_epochs=n_epochs)
 
-        batch_size = 16
-        n_iters = y.shape[0] // batch_size + 1
+            y_true = idist.all_gather(y_true)
+            y_preds = idist.all_gather(y_preds)
 
-        for i in range(n_iters):
-            idx = i * batch_size
-            acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+            assert (
+                acc._num_correct.device == metric_device
+            ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
 
-        assert (
-            acc._num_correct.device == metric_device
-        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
+            assert "acc" in engine.state.metrics
+            res = engine.state.metrics["acc"]
+            if isinstance(res, torch.Tensor):
+                res = res.cpu().numpy()
 
-        n = acc._num_examples
-        assert n == y.numel() / y.size(dim=1)
+            true_res = accuracy_score(y_true.cpu().numpy(), torch.argmax(y_preds, dim=1).cpu().numpy())
 
-        # gather y_pred, y
-        y_pred = idist.all_gather(y_pred)
-        y = idist.all_gather(y)
-
-        np_y_pred = to_numpy_multilabel(y_pred.cpu())  # (N, C, L, ...) -> (N * L * ..., C)
-        np_y = to_numpy_multilabel(y.cpu())  # (N, C, L, ...) -> (N * L ..., C)
-
-        assert acc._type == "multilabel"
-        res = acc.compute()
-        assert n == acc._num_examples
-        assert isinstance(res, float)
-        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
+            assert pytest.approx(res) == true_res
 
-    # check multiple random inputs as random exact occurencies are rare
-    for _ in range(3):
-        _test("cpu")
-    if device.type != "xla":
-        _test(idist.device())
+            metric_state = acc.state_dict()
+            saved__num_correct = acc._num_correct
+            saved__num_examples = acc._num_examples
+            acc.reset()
+            acc.load_state_dict(metric_state)
+            assert acc._num_examples == saved__num_examples
+            assert (acc._num_correct == saved__num_correct).all()
 
-
-def _test_distrib_integration_multiclass(device):
-    rank = idist.get_rank()
+    @pytest.mark.parametrize("n_epochs", [1, 2])
+    def test_integration_multilabel(self, n_epochs):
+        rank = idist.get_rank()
+        torch.manual_seed(12 + rank)
 
-    def _test(n_epochs, metric_device):
-        metric_device = torch.device(metric_device)
        n_iters = 80
        batch_size = 16
        n_classes = 10
 
-        torch.manual_seed(12 + rank)
-
-        y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
-        y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)
-
-        def update(engine, i):
-            return (
-                y_preds[i * batch_size : (i + 1) * batch_size, :],
-                y_true[i * batch_size : (i + 1) * batch_size],
-            )
+        metric_devices = [torch.device("cpu")]
+        device = idist.device()
+        if device.type != "xla":
+            metric_devices.append(device)
 
-        engine = Engine(update)
+        for metric_device in metric_devices:
+            metric_device = torch.device(metric_device)
 
-        acc = Accuracy(device=metric_device)
-        acc.attach(engine, "acc")
+            y_true = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 8, 10)).to(device)
+            y_preds = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 8, 10)).to(device)
 
-        data = list(range(n_iters))
-        engine.run(data=data, max_epochs=n_epochs)
+            def update(engine, i):
+                return (
+                    y_preds[i * batch_size : (i + 1) * batch_size, ...],
+                    y_true[i * batch_size : (i + 1) * batch_size, ...],
+                )
 
-        y_true = idist.all_gather(y_true)
-        y_preds = idist.all_gather(y_preds)
+            engine = Engine(update)
 
-        assert (
-            acc._num_correct.device == metric_device
-        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
+            acc = Accuracy(is_multilabel=True, device=metric_device)
+            acc.attach(engine, "acc")
 
-        assert "acc" in engine.state.metrics
-        res = engine.state.metrics["acc"]
-        if isinstance(res, torch.Tensor):
-            res = res.cpu().numpy()
+            data = list(range(n_iters))
+            engine.run(data=data, max_epochs=n_epochs)
 
-        true_res = accuracy_score(y_true.cpu().numpy(), torch.argmax(y_preds, dim=1).cpu().numpy())
+            y_true = idist.all_gather(y_true)
+            y_preds = idist.all_gather(y_preds)
 
-        assert pytest.approx(res) == true_res
+            assert (
+                acc._num_correct.device == metric_device
+            ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
 
-        metric_state = acc.state_dict()
-        saved__num_correct = acc._num_correct
-        saved__num_examples = acc._num_examples
-        acc.reset()
-        acc.load_state_dict(metric_state)
-        assert acc._num_examples == saved__num_examples
-        assert (acc._num_correct == saved__num_correct).all()
+            assert "acc" in engine.state.metrics
+            res = engine.state.metrics["acc"]
+            if isinstance(res, torch.Tensor):
+                res = res.cpu().numpy()
 
-    metric_devices = ["cpu"]
-    if device.type != "xla":
-        metric_devices.append(idist.device())
-    for metric_device in metric_devices:
-        for _ in range(2):
-            _test(n_epochs=1, metric_device=metric_device)
-            _test(n_epochs=2, metric_device=metric_device)
+            true_res = accuracy_score(to_numpy_multilabel(y_true), to_numpy_multilabel(y_preds))
+            assert pytest.approx(res) == true_res
 
-
-def _test_distrib_integration_multilabel(device):
-    rank = idist.get_rank()
+    def test_accumulator_device(self):
+        metric_devices = [torch.device("cpu")]
+        device = idist.device()
+        if device.type != "xla":
+            metric_devices.append(device)
+
+        for metric_device in metric_devices:
+            acc = Accuracy(device=metric_device)
+            assert acc._device == metric_device
+            assert (
+                acc._num_correct.device == metric_device
+            ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
+
+            y_pred = torch.randint(0, 2, size=(10,), device=device, dtype=torch.long)
+            y = torch.randint(0, 2, size=(10,), device=device, dtype=torch.long)
+            acc.update((y_pred, y))
+
+            assert (
+                acc._num_correct.device == metric_device
+            ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
+
+    @pytest.mark.parametrize("n_epochs", [1, 2])
+    def test_integration_list_of_tensors_or_numbers(self, n_epochs):
+        rank = idist.get_rank()
+        torch.manual_seed(12 + rank)
 
-    def _test(n_epochs, metric_device):
-        metric_device = torch.device(metric_device)
        n_iters = 80
        batch_size = 16
        n_classes = 10
 
-        torch.manual_seed(12 + rank)
-
-        y_true = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 8, 10)).to(device)
-        y_preds = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 8, 10)).to(device)
-
-        def update(engine, i):
-            return (
-                y_preds[i * batch_size : (i + 1) * batch_size, ...],
-                y_true[i * batch_size : (i + 1) * batch_size, ...],
-            )
-
-        engine = Engine(update)
-
-        acc = Accuracy(is_multilabel=True, device=metric_device)
-        acc.attach(engine, "acc")
-
-        data = list(range(n_iters))
-        engine.run(data=data, max_epochs=n_epochs)
-
-        y_true = idist.all_gather(y_true)
-        y_preds = idist.all_gather(y_preds)
-
-        assert (
-            acc._num_correct.device == metric_device
-        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
-
-        assert "acc" in engine.state.metrics
-        res = engine.state.metrics["acc"]
-        if isinstance(res, torch.Tensor):
-            res = res.cpu().numpy()
-
-        true_res = accuracy_score(to_numpy_multilabel(y_true), to_numpy_multilabel(y_preds))
-
-        assert pytest.approx(res) == true_res
-
-    metric_devices = ["cpu"]
-    if device.type != "xla":
-        metric_devices.append(idist.device())
-    for metric_device in metric_devices:
-        for _ in range(2):
-            _test(n_epochs=1, metric_device=metric_device)
-            _test(n_epochs=2, metric_device=metric_device)
+        metric_devices = [torch.device("cpu")]
+        device = idist.device()
+        if device.type != "xla":
+            metric_devices.append(device)
 
-
-def _test_distrib_accumulator_device(device):
-    metric_devices = [torch.device("cpu")]
-    if device.type != "xla":
-        metric_devices.append(idist.device())
-    for metric_device in metric_devices:
-        acc = Accuracy(device=metric_device)
-        assert acc._device == metric_device
-        assert (
-            acc._num_correct.device == metric_device
-        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
+        for metric_device in metric_devices:
+            y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
+            y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)
 
-        y_pred = torch.randint(0, 2, size=(10,), device=device, dtype=torch.long)
-        y = torch.randint(0, 2, size=(10,), device=device, dtype=torch.long)
-        acc.update((y_pred, y))
+            def update(_, i):
+                return (
+                    [v for v in y_preds[i * batch_size : (i + 1) * batch_size, ...]],
+                    [v.item() for v in y_true[i * batch_size : (i + 1) * batch_size]],
+                )
 
-        assert (
-            acc._num_correct.device == metric_device
-        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
+            engine = Engine(update)
 
-
-def _test_distrib_integration_list_of_tensors_or_numbers(device):
-    rank = idist.get_rank()
+            acc = Accuracy(device=metric_device)
+            acc.attach(engine, "acc")
 
-    def _test(n_epochs, metric_device):
-        metric_device = torch.device(metric_device)
-        n_iters = 80
-        batch_size = 16
-        n_classes = 10
+            data = list(range(n_iters))
+            engine.run(data=data, max_epochs=n_epochs)
 
-        torch.manual_seed(12 + rank)
+            y_true = idist.all_gather(y_true)
+            y_preds = idist.all_gather(y_preds)
 
-        y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
-        y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)
+            assert (
+                acc._num_correct.device == metric_device
+            ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
 
-        def update(_, i):
-            return (
-                [v for v in y_preds[i * batch_size : (i + 1) * batch_size, ...]],
-                [v.item() for v in y_true[i * batch_size : (i + 1) * batch_size]],
-            )
+            assert "acc" in engine.state.metrics
+            res = engine.state.metrics["acc"]
+            if isinstance(res, torch.Tensor):
+                res = res.cpu().numpy()
 
-        engine = Engine(update)
-
-        acc = Accuracy(device=metric_device)
-        acc.attach(engine, "acc")
-
-        data = list(range(n_iters))
-        engine.run(data=data, max_epochs=n_epochs)
-
-        y_true = idist.all_gather(y_true)
-        y_preds = idist.all_gather(y_preds)
-
-        assert (
-            acc._num_correct.device == metric_device
-        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"
-
-        assert "acc" in engine.state.metrics
-        res = engine.state.metrics["acc"]
-        if isinstance(res, torch.Tensor):
-            res = res.cpu().numpy()
-
-        true_res = accuracy_score(y_true.cpu().numpy(), torch.argmax(y_preds, dim=1).cpu().numpy())
-        assert pytest.approx(res) == true_res
-
-    metric_devices = ["cpu"]
-    if device.type != "xla":
-        metric_devices.append(idist.device())
-    for metric_device in metric_devices:
-        for _ in range(2):
-            _test(n_epochs=1, metric_device=metric_device)
-            _test(n_epochs=2, metric_device=metric_device)
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
-def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
-    device = idist.device()
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_integration_multiclass(device)
-    _test_distrib_integration_multilabel(device)
-    _test_distrib_accumulator_device(device)
-    _test_distrib_integration_list_of_tensors_or_numbers(device)
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
-def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
-    device = idist.device()
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_integration_multiclass(device)
-    _test_distrib_integration_multilabel(device)
-    _test_distrib_accumulator_device(device)
-    _test_distrib_integration_list_of_tensors_or_numbers(device)
-
-
-@pytest.mark.distributed
-@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
-@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
-def test_distrib_hvd(gloo_hvd_executor):
-    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
-    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
-
-    gloo_hvd_executor(_test_distrib_multilabel_input_NHW, (device,), np=nproc, do_init=True)
-    gloo_hvd_executor(_test_distrib_integration_multiclass, (device,), np=nproc, do_init=True)
-    gloo_hvd_executor(_test_distrib_integration_multilabel, (device,), np=nproc, do_init=True)
-    gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True)
-    gloo_hvd_executor(_test_distrib_integration_list_of_tensors_or_numbers, (device,), np=nproc, do_init=True)
-
-
-@pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
-@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_distrib_single_device_xla():
-    device = idist.device()
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_integration_multiclass(device)
-    _test_distrib_integration_multilabel(device)
-    _test_distrib_accumulator_device(device)
-    _test_distrib_integration_list_of_tensors_or_numbers(device)
-
-
-def _test_distrib_xla_nprocs(index):
-    device = idist.device()
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_integration_multiclass(device)
-    _test_distrib_integration_multilabel(device)
-    _test_distrib_accumulator_device(device)
-    _test_distrib_integration_list_of_tensors_or_numbers(device)
-
-
-@pytest.mark.tpu
-@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
-@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
-def test_distrib_xla_nprocs(xmp_executor):
-    n = int(os.environ["NUM_TPU_WORKERS"])
-    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
-
-
-@pytest.mark.multinode_distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
-def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
-    device = idist.device()
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_integration_multiclass(device)
-    _test_distrib_integration_multilabel(device)
-    _test_distrib_accumulator_device(device)
-    _test_distrib_integration_list_of_tensors_or_numbers(device)
-
-
-@pytest.mark.multinode_distributed
-@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
-@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
-def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
-    device = idist.device()
-    _test_distrib_multilabel_input_NHW(device)
-    _test_distrib_integration_multiclass(device)
-    _test_distrib_integration_multilabel(device)
-    _test_distrib_accumulator_device(device)
-    _test_distrib_integration_list_of_tensors_or_numbers(device)
+            true_res = accuracy_score(y_true.cpu().numpy(), torch.argmax(y_preds, dim=1).cpu().numpy())
+            assert pytest.approx(res) == true_res
 
 
 def test_skip_unrolling():
diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py
index 20583a00f00e..a865ef6ceb13 100644
--- a/tests/ignite/metrics/test_precision.py
+++ b/tests/ignite/metrics/test_precision.py
@@ -101,67 +101,35 @@ def ignite_average_to_scikit_average(average, data_type: str):
         raise ValueError(f"Wrong average parameter `{average}`")
 
 
+@pytest.mark.parametrize("n_times", range(3))
 @pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])
-def test_binary_input(average):
-    pr = Precision(average=average)
+def test_binary_input(n_times, available_device, average, test_data_binary):
+    pr = Precision(average=average, device=available_device)
     assert pr._updated is False
+    y_pred, y, batch_size = test_data_binary
 
-    def _test(y_pred, y, batch_size):
-        pr.reset()
-        assert pr._updated is False
+    pr.reset()
+    assert pr._updated is False
 
-        if batch_size > 1:
-            n_iters = y.shape[0] // batch_size + 1
-            for i in range(n_iters):
-                idx = i * batch_size
-                pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
-        else:
-            pr.update((y_pred, y))
+    if batch_size > 1:
+        n_iters = y.shape[0] // batch_size + 1
+        for i in range(n_iters):
+            idx = i * batch_size
+            pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+    else:
+        pr.update((y_pred, y))
 
-        np_y = y.numpy().ravel()
-        np_y_pred = y_pred.numpy().ravel()
-
-        assert pr._type == "binary"
-        assert pr._updated is True
-        assert isinstance(pr.compute(), torch.Tensor if not average else float)
-        pr_compute = pr.compute().numpy() if not average else pr.compute()
-        sk_average_parameter = ignite_average_to_scikit_average(average, "binary")
-        assert precision_score(
-            np_y, np_y_pred, average=sk_average_parameter, labels=[0, 1], zero_division=0
-        ) == pytest.approx(pr_compute)
-
-    def get_test_cases():
-        test_cases = [
-            # Binary accuracy on input of shape (N, 1) or (N, )
-            (torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)), 1),
-            (torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50,)), torch.randint(0, 2, size=(50,)), 16),
-            (torch.randint(0, 2, size=(50, 1)), torch.randint(0, 2, size=(50, 1)), 16),
-            # Binary accuracy on input of shape (N, L)
-            (torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
-            (torch.randint(0, 2, size=(10, 1, 5)), torch.randint(0, 2, size=(10, 1, 5)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16),
-            (torch.randint(0, 2, size=(50, 1, 5)), torch.randint(0, 2, size=(50, 1, 5)), 16),
-            # Binary accuracy on input of shape (N, H, W)
-            (torch.randint(0, 2, size=(10, 12, 10)), torch.randint(0, 2, size=(10, 12, 10)), 1),
-            (torch.randint(0, 2, size=(10, 1, 12, 10)), torch.randint(0, 2, size=(10, 1, 12, 10)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50, 12, 10)), torch.randint(0, 2, size=(50, 12, 10)), 16),
-            (torch.randint(0, 2, size=(50, 1, 12, 10)), torch.randint(0, 2, size=(50, 1, 12, 10)), 16),
-            # Corner case with all zeros predictions
-            (torch.zeros(size=(10,), dtype=torch.long), torch.randint(0, 2, size=(10,)), 1),
-            (torch.zeros(size=(10, 1), dtype=torch.long), torch.randint(0, 2, size=(10, 1)), 1),
-        ]
-
-        return test_cases
-
-    for _ in range(5):
-        # check multiple random inputs as random exact occurencies are rare
-        test_cases = get_test_cases()
-        for y_pred, y, batch_size in test_cases:
-            _test(y_pred, y, batch_size)
+    np_y = y.numpy().ravel()
+    np_y_pred = y_pred.numpy().ravel()
+
+    assert pr._type == "binary"
+    assert pr._updated is True
+    assert isinstance(pr.compute(), torch.Tensor if not average else float)
+    pr_compute = pr.compute().cpu().numpy() if not average else pr.compute()
+    sk_average_parameter = ignite_average_to_scikit_average(average, "binary")
+    assert precision_score(
+        np_y, np_y_pred, average=sk_average_parameter, labels=[0, 1], zero_division=0
+    ) == pytest.approx(pr_compute)
 
 
 def test_multiclass_wrong_inputs():
@@ -221,69 +189,37 @@ def test_multiclass_wrong_inputs():
         pr.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).float()))
 
 
+@pytest.mark.parametrize("n_times", range(3))
 @pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])
-def test_multiclass_input(average):
-    pr = Precision(average=average)
+def test_multiclass_input(n_times, available_device, average, test_data_multiclass):
+    pr = Precision(average=average, device=available_device)
     assert pr._updated is False
 
-    def _test(y_pred, y, batch_size):
-        pr.reset()
-        assert pr._updated is False
+    y_pred, y, batch_size = test_data_multiclass
+    pr.reset()
+    assert pr._updated is False
 
-        if batch_size > 1:
-            n_iters = y.shape[0] // batch_size + 1
-            for i in range(n_iters):
-                idx = i * batch_size
-                pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
-        else:
-            pr.update((y_pred, y))
+    if batch_size > 1:
+        n_iters = y.shape[0] // batch_size + 1
+        for i in range(n_iters):
+            idx = i * batch_size
+            pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+    else:
+        pr.update((y_pred, y))
+
+    num_classes = y_pred.shape[1]
+    np_y_pred = y_pred.argmax(dim=1).numpy().ravel()
+    np_y = y.numpy().ravel()
 
-        num_classes = y_pred.shape[1]
-        np_y_pred = y_pred.argmax(dim=1).numpy().ravel()
-        np_y = y.numpy().ravel()
-
-        assert pr._type == "multiclass"
-        assert pr._updated is True
-        assert isinstance(pr.compute(), torch.Tensor if not average else float)
-        pr_compute = pr.compute().numpy() if not average else pr.compute()
-        sk_average_parameter = ignite_average_to_scikit_average(average, "multiclass")
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-            sk_compute = precision_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter)
-        assert sk_compute == pytest.approx(pr_compute)
-
-    def get_test_cases():
-        test_cases = [
-            # Multiclass input data of shape (N, ) and (N, C)
-            (torch.rand(10, 6), torch.randint(0, 6, size=(10,)), 1),
-            (torch.rand(10, 4), torch.randint(0, 4, size=(10,)), 1),
-            # updated batches
-            (torch.rand(50, 6), torch.randint(0, 6, size=(50,)), 16),
-            (torch.rand(50, 4), torch.randint(0, 4, size=(50,)), 16),
-            # Multiclass input data of shape (N, L) and (N, C, L)
-            (torch.rand(10, 5, 8), torch.randint(0, 5, size=(10, 8)), 1),
-            (torch.rand(10, 8, 12), torch.randint(0, 8, size=(10, 12)), 1),
-            # updated batches
-            (torch.rand(50, 5, 8), torch.randint(0, 5, size=(50, 8)), 16),
-            (torch.rand(50, 8, 12), torch.randint(0, 8, size=(50, 12)), 16),
-            # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
-            (torch.rand(10, 5, 18, 16), torch.randint(0, 5, size=(10, 18, 16)), 1),
-            (torch.rand(10, 7, 20, 12), torch.randint(0, 7, size=(10, 20, 12)), 1),
-            # updated batches
-            (torch.rand(50, 5, 18, 16), torch.randint(0, 5, size=(50, 18, 16)), 16),
-            (torch.rand(50, 7, 20, 12), torch.randint(0, 7, size=(50, 20, 12)), 16),
-            # Corner case with all zeros predictions
-            (torch.zeros(size=(10, 6)), torch.randint(0, 6, size=(10,)), 1),
-            (torch.zeros(size=(10, 4)), torch.randint(0, 4, size=(10,)), 1),
-        ]
-
-        return test_cases
-
-    for _ in range(5):
-        # check multiple random inputs as random exact occurencies are rare
-        test_cases = get_test_cases()
-        for y_pred, y, batch_size in test_cases:
-            _test(y_pred, y, batch_size)
+    assert pr._type == "multiclass"
+    assert pr._updated is True
+    assert isinstance(pr.compute(), torch.Tensor if not average else float)
+    pr_compute = pr.compute().cpu().numpy() if not average else pr.compute()
+    sk_average_parameter = ignite_average_to_scikit_average(average, "multiclass")
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
+        sk_compute = precision_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter)
+    assert sk_compute == pytest.approx(pr_compute)
 
 
 def test_multilabel_wrong_inputs():
@@ -320,66 +256,34 @@ def to_numpy_multilabel(y):
     return y
 
 
+@pytest.mark.parametrize("n_times", range(3))
 @pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted", "samples"])
-def test_multilabel_input(average):
-    pr = Precision(average=average, is_multilabel=True)
+def test_multilabel_input(n_times, available_device, average, test_data_multilabel):
+    pr = Precision(average=average, is_multilabel=True, device=available_device)
     assert pr._updated is False
 
-    def _test(y_pred, y, batch_size):
-        pr.reset()
-        assert pr._updated is False
+    y_pred, y, batch_size = test_data_multilabel
+    pr.reset()
+    assert pr._updated is False
 
-        if batch_size > 1:
-            n_iters = y.shape[0] // batch_size + 1
-            for i in range(n_iters):
-                idx = i * batch_size
-                pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
-        else:
-            pr.update((y_pred, y))
+    if batch_size > 1:
+        n_iters = y.shape[0] // batch_size + 1
+        for i in range(n_iters):
+            idx = i * batch_size
+            pr.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+    else:
+        pr.update((y_pred, y))
 
-        np_y_pred = to_numpy_multilabel(y_pred)
-        np_y = to_numpy_multilabel(y)
-
-        assert pr._type == "multilabel"
-        assert pr._updated is True
-        pr_compute = pr.compute().numpy() if not average else pr.compute()
-        sk_average_parameter = ignite_average_to_scikit_average(average, "multilabel")
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-            assert precision_score(np_y, np_y_pred, average=sk_average_parameter) == pytest.approx(pr_compute)
-
-    def get_test_cases():
-        test_cases = [
-            # Multilabel input data of shape (N, C)
-            (torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
-            (torch.randint(0, 2, size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16),
-            (torch.randint(0, 2, size=(50, 4)), torch.randint(0, 2, size=(50, 4)), 16),
-            # Multilabel input data of shape (N, C, L)
-            (torch.randint(0, 2, size=(10, 5, 10)), torch.randint(0, 2, size=(10, 5, 10)), 1),
-            (torch.randint(0, 2, size=(10, 4, 10)), torch.randint(0, 2, size=(10, 4, 10)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50, 5, 10)), torch.randint(0, 2, size=(50, 5, 10)), 16),
-            (torch.randint(0, 2, size=(50, 4, 10)), torch.randint(0, 2, size=(50, 4, 10)), 16),
-            # Multilabel input data of shape (N, C, H, W)
-            (torch.randint(0, 2, size=(10, 5, 18, 16)), torch.randint(0, 2, size=(10, 5, 18, 16)), 1),
-            (torch.randint(0, 2, size=(10, 4, 20, 23)), torch.randint(0, 2, size=(10, 4, 20, 23)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50, 5, 18, 16)), torch.randint(0, 2, size=(50, 5, 18, 16)), 16),
-            (torch.randint(0, 2, size=(50, 4, 20, 23)), torch.randint(0, 2, size=(50, 4, 20, 23)), 16),
-            # Corner case with all zeros predictions
-            (torch.zeros(size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
-            (torch.zeros(size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1),
-        ]
-
-        return test_cases
-
-    for _ in range(5):
-        # check multiple random inputs as random exact occurencies are rare
-        test_cases = get_test_cases()
-        for y_pred, y, batch_size in test_cases:
-            _test(y_pred, y, batch_size)
+    np_y_pred = to_numpy_multilabel(y_pred)
+    np_y = to_numpy_multilabel(y)
+
+    assert pr._type == "multilabel"
+    assert pr._updated is True
+    pr_compute = pr.compute().cpu().numpy() if not average else pr.compute()
+    sk_average_parameter = ignite_average_to_scikit_average(average, "multilabel")
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
+        assert precision_score(np_y, np_y_pred, average=sk_average_parameter) == pytest.approx(pr_compute)
 
 
 @pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])
diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py
index 389f288a34ca..d813c6ac434e 100644
--- a/tests/ignite/metrics/test_recall.py
+++ b/tests/ignite/metrics/test_recall.py
@@ -104,65 +104,34 @@ def ignite_average_to_scikit_average(average, data_type: str):
         raise ValueError(f"Wrong average parameter `{average}`")
 
 
+@pytest.mark.parametrize("n_times", range(3))
 @pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])
-def test_binary_input(average):
-    re = Recall(average=average)
+def test_binary_input(n_times, available_device, average, test_data_binary):
+    re = Recall(average=average, device=available_device)
     assert re._updated is False
 
-    def _test(y_pred, y, batch_size):
-        re.reset()
-        assert re._updated is False
+    y_pred, y, batch_size = test_data_binary
 
-        if batch_size > 1:
-            n_iters = y.shape[0] // batch_size + 1
-            for i in range(n_iters):
-                idx = i * batch_size
-                re.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
-        else:
-            re.update((y_pred, y))
-
-        np_y = y.numpy().ravel()
-        np_y_pred = y_pred.numpy().ravel()
-
-        assert re._type == "binary"
-        assert re._updated is True
-        assert isinstance(re.compute(), torch.Tensor if not average else float)
-        re_compute = re.compute().numpy() if not average else re.compute()
-        sk_average_parameter = ignite_average_to_scikit_average(average, "binary")
-        assert recall_score(np_y, np_y_pred, average=sk_average_parameter, labels=[0, 1]) == pytest.approx(re_compute)
-
-    def get_test_cases():
-        test_cases = [
-            # Binary accuracy on input of shape (N, 1) or (N, )
-            (torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)), 1),
-            (torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50,)), torch.randint(0, 2, size=(50,)), 16),
-            (torch.randint(0, 2, size=(50, 1)), torch.randint(0, 2, size=(50, 1)), 16),
-            # Binary accuracy on input of shape (N, L)
-            (torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
-            (torch.randint(0, 2, size=(10, 1, 5)), torch.randint(0, 2, size=(10, 1, 5)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16),
-            (torch.randint(0, 2, size=(50, 1, 5)), torch.randint(0, 2, size=(50, 1, 5)), 16),
-            # Binary accuracy on input of shape (N, H, W)
-            (torch.randint(0, 2, size=(10, 12, 10)), torch.randint(0, 2, size=(10, 12, 10)), 1),
-            (torch.randint(0, 2, size=(10, 1, 12, 10)), torch.randint(0, 2, size=(10, 1, 12, 10)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50, 12, 10)), torch.randint(0, 2, size=(50, 12, 10)), 16),
-            (torch.randint(0, 2, size=(50, 1, 12, 10)), torch.randint(0, 2, size=(50, 1, 12, 10)), 16),
-            # Corner case with all zeros predictions
-            (torch.zeros(size=(10,), dtype=torch.long), torch.randint(0, 2, size=(10,)), 1),
-            (torch.zeros(size=(10, 1), dtype=torch.long), torch.randint(0, 2, size=(10, 1)), 1),
-        ]
-
-        return test_cases
-
-    for _ in range(5):
-        # check multiple random inputs as random exact occurencies are rare
-        test_cases = get_test_cases()
-        for y_pred, y, batch_size in test_cases:
-            _test(y_pred, y, batch_size)
+    re.reset()
+    assert re._updated is False
+
+    if batch_size > 1:
+        n_iters = y.shape[0] // batch_size + 1
+        for i in range(n_iters):
+            idx = i * batch_size
+            re.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+    else:
+        re.update((y_pred, y))
+
+    np_y = y.numpy().ravel()
+    np_y_pred = y_pred.numpy().ravel()
+
+    assert re._type == "binary"
+    assert re._updated is True
+    assert isinstance(re.compute(), torch.Tensor if not average else float)
+    re_compute = re.compute().cpu().numpy() if not average else re.compute()
+    sk_average_parameter = ignite_average_to_scikit_average(average, "binary")
+    assert recall_score(np_y, np_y_pred, average=sk_average_parameter, labels=[0, 1]) == pytest.approx(re_compute)
 
 
 def test_multiclass_wrong_inputs():
@@ -222,69 +191,37 @@ def test_multiclass_wrong_inputs():
         re.update((torch.rand(10, 5), torch.randint(0, 5, size=(10,)).float()))
 
 
+@pytest.mark.parametrize("n_times", range(3))
 @pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])
-def test_multiclass_input(average):
-    re = Recall(average=average)
+def test_multiclass_input(n_times, available_device, average, test_data_multiclass):
+    re = Recall(average=average, device=available_device)
     assert re._updated is False
 
-    def _test(y_pred, y, batch_size):
-        re.reset()
-        assert re._updated is False
+    y_pred, y, batch_size = test_data_multiclass
+    re.reset()
+    assert re._updated is False
 
-        if batch_size > 1:
-            n_iters = y.shape[0] // batch_size + 1
-            for i in range(n_iters):
-                idx = i * batch_size
-                re.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
-        else:
-            re.update((y_pred, y))
-
-        num_classes = y_pred.shape[1]
-        np_y_pred = y_pred.argmax(dim=1).numpy().ravel()
-        np_y = y.numpy().ravel()
-
-        assert re._type == "multiclass"
-        assert re._updated is True
-        assert isinstance(re.compute(), torch.Tensor if not average else float)
-        re_compute = re.compute().numpy() if not average else re.compute()
-        sk_average_parameter = ignite_average_to_scikit_average(average, "multiclass")
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-            sk_compute = recall_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter)
-        assert sk_compute == pytest.approx(re_compute)
-
-    def get_test_cases():
-        test_cases = [
-            # Multiclass input data of shape (N, ) and (N, C)
-            (torch.rand(10, 6), torch.randint(0, 6, size=(10,)), 1),
-            (torch.rand(10, 4), torch.randint(0, 4, size=(10,)), 1),
-            # updated batches
-            (torch.rand(50, 6), torch.randint(0, 6, size=(50,)), 16),
-            (torch.rand(50, 4), torch.randint(0, 4, size=(50,)), 16),
-            # Multiclass input data of shape (N, L) and (N, C, L)
-            (torch.rand(10, 5, 8), torch.randint(0, 5, size=(10, 8)), 1),
-            (torch.rand(10, 8, 12), torch.randint(0, 8, size=(10, 12)), 1),
-            # updated batches
-            (torch.rand(50, 5, 8), torch.randint(0, 5, size=(50, 8)), 16),
-            (torch.rand(50, 8, 12), torch.randint(0, 8, size=(50, 12)), 16),
-            # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
-            (torch.rand(10, 5, 18, 16), torch.randint(0, 5, size=(10, 18, 16)), 1),
-            (torch.rand(10, 7, 20, 12), torch.randint(0, 7, size=(10, 20, 12)), 1),
-            # updated batches
-            (torch.rand(50, 5, 18, 16), torch.randint(0, 5, size=(50, 18, 16)), 16),
-            (torch.rand(50, 7, 20, 12), torch.randint(0, 7, size=(50, 20, 12)), 16),
-            # Corner case with all zeros predictions
-            (torch.zeros(size=(10, 6)), torch.randint(0, 6, size=(10,)), 1),
-            (torch.zeros(size=(10, 4)), torch.randint(0, 4, size=(10,)), 1),
-        ]
-
-        return test_cases
-
-    for _ in range(5):
-        # check multiple random inputs as random exact occurencies are rare
-        test_cases = get_test_cases()
-        for y_pred, y, batch_size in test_cases:
-            _test(y_pred, y, batch_size)
+    if batch_size > 1:
+        n_iters = y.shape[0] // batch_size + 1
+        for i in range(n_iters):
+            idx = i * batch_size
+            re.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+    else:
+        re.update((y_pred, y))
+
+    num_classes = y_pred.shape[1]
+    np_y_pred = y_pred.argmax(dim=1).numpy().ravel()
+    np_y = y.numpy().ravel()
+
+    assert re._type == "multiclass"
+    assert re._updated is True
+    assert isinstance(re.compute(), torch.Tensor if not average else float)
+    re_compute = re.compute().cpu().numpy() if not average else re.compute()
+    sk_average_parameter = ignite_average_to_scikit_average(average, "multiclass")
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
+        sk_compute = recall_score(np_y, np_y_pred, labels=range(0, num_classes), average=sk_average_parameter)
+    assert sk_compute == pytest.approx(re_compute)
 
 
 def test_multilabel_wrong_inputs():
@@ -321,66 +258,35 @@ def to_numpy_multilabel(y):
     return y
 
 
-@pytest.mark.parametrize("average", [None, False, "macro", "micro", "samples"])
-def test_multilabel_input(average):
-    re = Recall(average=average, is_multilabel=True)
+@pytest.mark.parametrize("n_times", range(3))
+@pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted", "samples"])
+def test_multilabel_input(n_times, available_device, average, test_data_multilabel):
+
+    re = Recall(average=average, is_multilabel=True, device=available_device)
+    assert re._updated is False
+
+    y_pred, y, batch_size = test_data_multilabel
+    re.reset()
     assert re._updated is False
 
-    def _test(y_pred, y, batch_size):
-        re.reset()
-        assert re._updated is False
+    if batch_size > 1:
+        n_iters = y.shape[0] // batch_size + 1
+        for i in range(n_iters):
+            idx = i * batch_size
+            re.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
+    else:
+        re.update((y_pred, y))
+
+    np_y_pred = to_numpy_multilabel(y_pred)
+    np_y = to_numpy_multilabel(y)
 
-        if batch_size > 1:
-            n_iters = y.shape[0] // batch_size + 1
-            for i in range(n_iters):
-                idx = i * batch_size
-                re.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
-        else:
-            re.update((y_pred, y))
-
-        np_y_pred = to_numpy_multilabel(y_pred)
-        np_y = to_numpy_multilabel(y)
-
-        assert re._type == "multilabel"
-        assert re._updated is True
-        re_compute = re.compute().numpy() if not average else re.compute()
-        sk_average_parameter = ignite_average_to_scikit_average(average, "multilabel")
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-            assert recall_score(np_y, np_y_pred, average=sk_average_parameter) == pytest.approx(re_compute)
-
-    def get_test_cases():
-        test_cases = [
-            # Multilabel input data of shape (N, C)
-            (torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
-            (torch.randint(0, 2, size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16),
-            (torch.randint(0, 2, size=(50, 4)), torch.randint(0, 2, size=(50, 4)), 16),
-            # Multilabel input data of shape (N, H, W)
-            (torch.randint(0, 2, size=(10, 5, 10)), torch.randint(0, 2, size=(10, 5, 10)), 1),
-            (torch.randint(0, 2, size=(10, 4, 10)), torch.randint(0, 2, size=(10, 4, 10)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50, 5, 10)), torch.randint(0, 2, size=(50, 5, 10)), 16),
-            (torch.randint(0, 2, size=(50, 4, 10)), torch.randint(0, 2, size=(50, 4, 10)), 16),
-            # Multilabel input data of shape (N, C, H, W, ...)
-            (torch.randint(0, 2, size=(10, 5, 18, 16)), torch.randint(0, 2, size=(10, 5, 18, 16)), 1),
-            (torch.randint(0, 2, size=(10, 4, 20, 23)), torch.randint(0, 2, size=(10, 4, 20, 23)), 1),
-            # updated batches
-            (torch.randint(0, 2, size=(50, 5, 18, 16)), torch.randint(0, 2, size=(50, 5, 18, 16)), 16),
-            (torch.randint(0, 2, size=(50, 4, 20, 23)), torch.randint(0, 2, size=(50, 4, 20, 23)), 16),
-            # Corner case with all zeros predictions
-            (torch.zeros(size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
-            (torch.zeros(size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1),
-        ]
-
-        return test_cases
-
-    for _ in range(5):
-        # check multiple random inputs as random exact occurencies are rare
-        test_cases = get_test_cases()
-        for y_pred, y, batch_size in test_cases:
-            _test(y_pred, y, batch_size)
+    assert re._type == "multilabel"
+    assert re._updated is True
+    re_compute = re.compute().cpu().numpy() if not average else re.compute()
+    sk_average_parameter = ignite_average_to_scikit_average(average, "multilabel")
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", category=UndefinedMetricWarning)
+        assert recall_score(np_y, np_y_pred, average=sk_average_parameter) == pytest.approx(re_compute)
 
 
 @pytest.mark.parametrize("average", [None, False, "macro", "micro", "weighted"])