Skip to content

Commit 1c0818f

Browse files
authored
Updated accuracy / precision tests (#3333)
Added metrics devices: cpu, cuda, mps if available
1 parent b950b46 commit 1c0818f

File tree

9 files changed

+482
-769
lines changed

9 files changed

+482
-769
lines changed

tests/ignite/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import torch
22

3-
4-
def cpu_and_maybe_cuda():
5-
return ("cpu",) + (("cuda",) if torch.cuda.is_available() else ())
3+
from ignite.distributed.comp_models.base import _torch_version_gt_112
64

75

86
def is_mps_available_and_functional():
7+
if not _torch_version_gt_112:
8+
return False
9+
910
if not torch.backends.mps.is_available():
1011
return False
1112
try:

tests/ignite/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import torch.distributed as dist
1414

1515
import ignite.distributed as idist
16+
from tests.ignite import is_mps_available_and_functional
1617

1718

1819
def pytest_addoption(parser):
@@ -72,6 +73,9 @@ def term_handler():
7273
params=[
7374
"cpu",
7475
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no CUDA support")),
76+
pytest.param(
77+
"mps", marks=pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS support")
78+
),
7579
]
7680
)
7781
def available_device(request):

tests/ignite/distributed/test_auto.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import ignite.distributed as idist
1414
from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
15-
from ignite.distributed.comp_models.base import _torch_version_gt_112
1615
from tests.ignite import is_mps_available_and_functional
1716

1817

@@ -181,10 +180,7 @@ def _test_auto_model_optimizer(ws, device):
181180
assert optimizer.backward_passes_per_step == backward_passes_per_step
182181

183182

184-
@pytest.mark.skipif(
185-
(not _torch_version_gt_112) or (torch.backends.mps.is_available() and not is_mps_available_and_functional()),
186-
reason="Skip if MPS not functional",
187-
)
183+
@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if MPS not functional")
188184
def test_auto_methods_no_dist():
189185
_test_auto_dataloader(1, 1, batch_size=1)
190186
_test_auto_dataloader(1, 1, batch_size=10, num_workers=2)

tests/ignite/distributed/test_launcher.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from packaging.version import Version
99

1010
import ignite.distributed as idist
11-
from ignite.distributed.comp_models.base import _torch_version_gt_112
1211
from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support
1312
from tests.ignite import is_mps_available_and_functional
1413

@@ -56,10 +55,7 @@ def execute(cmd, env=None):
5655
return str(process.stdout.read()) + str(process.stderr.read())
5756

5857

59-
@pytest.mark.skipif(
60-
(not _torch_version_gt_112) or (torch.backends.mps.is_available() and not is_mps_available_and_functional()),
61-
reason="Skip if MPS not functional",
62-
)
58+
@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if MPS not functional")
6359
def test_check_idist_parallel_no_dist(exec_filepath):
6460
cmd = [sys.executable, "-u", exec_filepath]
6561
out = execute(cmd)

tests/ignite/engine/test_create_supervised.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from torch.optim import SGD
1313

1414
import ignite.distributed as idist
15-
from ignite.distributed.comp_models.base import _torch_version_gt_112
1615
from ignite.engine import (
1716
_check_arg,
1817
create_supervised_evaluator,
@@ -488,7 +487,7 @@ def test_create_supervised_trainer_on_cuda():
488487
_test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
489488

490489

491-
@pytest.mark.skipif(not (_torch_version_gt_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
490+
@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS")
492491
def test_create_supervised_trainer_on_mps():
493492
model_device = trainer_device = "mps"
494493
_test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
@@ -669,14 +668,14 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
669668
_test_mocked_supervised_evaluator(evaluator_device="cuda")
670669

671670

672-
@pytest.mark.skipif(not (_torch_version_gt_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
671+
@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS")
673672
def test_create_supervised_evaluator_on_mps():
674673
model_device = evaluator_device = "mps"
675674
_test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
676675
_test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
677676

678677

679-
@pytest.mark.skipif(not (_torch_version_gt_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
678+
@pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS")
680679
def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
681680
_test_create_supervised_evaluator(evaluator_device="mps")
682681
_test_mocked_supervised_evaluator(evaluator_device="mps")

tests/ignite/metrics/conftest.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import pytest
2+
import torch
3+
4+
5+
@pytest.fixture(params=range(14))
6+
def test_data_binary(request):
7+
return [
8+
# Binary accuracy on input of shape (N, 1) or (N, )
9+
(torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)), 1),
10+
(torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1),
11+
# updated batches
12+
(torch.randint(0, 2, size=(50,)), torch.randint(0, 2, size=(50,)), 16),
13+
(torch.randint(0, 2, size=(50, 1)), torch.randint(0, 2, size=(50, 1)), 16),
14+
# Binary accuracy on input of shape (N, L)
15+
(torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
16+
(torch.randint(0, 2, size=(10, 1, 5)), torch.randint(0, 2, size=(10, 1, 5)), 1),
17+
# updated batches
18+
(torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16),
19+
(torch.randint(0, 2, size=(50, 1, 5)), torch.randint(0, 2, size=(50, 1, 5)), 16),
20+
# Binary accuracy on input of shape (N, H, W)
21+
(torch.randint(0, 2, size=(10, 12, 10)), torch.randint(0, 2, size=(10, 12, 10)), 1),
22+
(torch.randint(0, 2, size=(10, 1, 12, 10)), torch.randint(0, 2, size=(10, 1, 12, 10)), 1),
23+
# updated batches
24+
(torch.randint(0, 2, size=(50, 12, 10)), torch.randint(0, 2, size=(50, 12, 10)), 16),
25+
(torch.randint(0, 2, size=(50, 1, 12, 10)), torch.randint(0, 2, size=(50, 1, 12, 10)), 16),
26+
# Corner case with all zeros predictions
27+
(torch.zeros(size=(10,), dtype=torch.long), torch.randint(0, 2, size=(10,)), 1),
28+
(torch.zeros(size=(10, 1), dtype=torch.long), torch.randint(0, 2, size=(10, 1)), 1),
29+
][request.param]
30+
31+
32+
@pytest.fixture(params=range(14))
33+
def test_data_multiclass(request):
34+
return [
35+
# Multiclass input data of shape (N, ) and (N, C)
36+
(torch.rand(10, 6), torch.randint(0, 6, size=(10,)), 1),
37+
(torch.rand(10, 4), torch.randint(0, 4, size=(10,)), 1),
38+
# updated batches
39+
(torch.rand(50, 6), torch.randint(0, 6, size=(50,)), 16),
40+
(torch.rand(50, 4), torch.randint(0, 4, size=(50,)), 16),
41+
# Multiclass input data of shape (N, L) and (N, C, L)
42+
(torch.rand(10, 5, 8), torch.randint(0, 5, size=(10, 8)), 1),
43+
(torch.rand(10, 8, 12), torch.randint(0, 8, size=(10, 12)), 1),
44+
# updated batches
45+
(torch.rand(50, 5, 8), torch.randint(0, 5, size=(50, 8)), 16),
46+
(torch.rand(50, 8, 12), torch.randint(0, 8, size=(50, 12)), 16),
47+
# Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...)
48+
(torch.rand(10, 5, 18, 16), torch.randint(0, 5, size=(10, 18, 16)), 1),
49+
(torch.rand(10, 7, 20, 12), torch.randint(0, 7, size=(10, 20, 12)), 1),
50+
# updated batches
51+
(torch.rand(50, 5, 18, 16), torch.randint(0, 5, size=(50, 18, 16)), 16),
52+
(torch.rand(50, 7, 20, 12), torch.randint(0, 7, size=(50, 20, 12)), 16),
53+
# Corner case with all zeros predictions
54+
(torch.zeros(size=(10, 6)), torch.randint(0, 6, size=(10,)), 1),
55+
(torch.zeros(size=(10, 4)), torch.randint(0, 4, size=(10,)), 1),
56+
][request.param]
57+
58+
59+
@pytest.fixture(params=range(14))
60+
def test_data_multilabel(request):
61+
return [
62+
# Multilabel input data of shape (N, C)
63+
(torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
64+
(torch.randint(0, 2, size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1),
65+
# updated batches
66+
(torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16),
67+
(torch.randint(0, 2, size=(50, 4)), torch.randint(0, 2, size=(50, 4)), 16),
68+
# Multilabel input data of shape (N, C, L)
69+
(torch.randint(0, 2, size=(10, 5, 10)), torch.randint(0, 2, size=(10, 5, 10)), 1),
70+
(torch.randint(0, 2, size=(10, 4, 10)), torch.randint(0, 2, size=(10, 4, 10)), 1),
71+
# updated batches
72+
(torch.randint(0, 2, size=(50, 5, 10)), torch.randint(0, 2, size=(50, 5, 10)), 16),
73+
(torch.randint(0, 2, size=(50, 4, 10)), torch.randint(0, 2, size=(50, 4, 10)), 16),
74+
# Multilabel input data of shape (N, C, H, W)
75+
(torch.randint(0, 2, size=(10, 5, 18, 16)), torch.randint(0, 2, size=(10, 5, 18, 16)), 1),
76+
(torch.randint(0, 2, size=(10, 4, 20, 23)), torch.randint(0, 2, size=(10, 4, 20, 23)), 1),
77+
# updated batches
78+
(torch.randint(0, 2, size=(50, 5, 18, 16)), torch.randint(0, 2, size=(50, 5, 18, 16)), 16),
79+
(torch.randint(0, 2, size=(50, 4, 20, 23)), torch.randint(0, 2, size=(50, 4, 20, 23)), 16),
80+
# Corner case with all zeros predictions
81+
(torch.zeros(size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1),
82+
(torch.zeros(size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1),
83+
][request.param]

0 commit comments

Comments
 (0)