
Commit e81aed3

Fixed failing tests
1 parent 2298ae0 commit e81aed3

4 files changed: +59 -45 lines changed


ignite/metrics/ssim.py

Lines changed: 17 additions & 6 deletions

@@ -120,7 +120,7 @@ def reset(self) -> None:
         self._num_examples = 0
 
     def _uniform(self, kernel_size: int) -> torch.Tensor:
-        kernel = torch.zeros(kernel_size)
+        kernel = torch.zeros(kernel_size, device=self._device)
 
         start_uniform_index = max(kernel_size // 2 - 2, 0)
         end_uniform_index = min(kernel_size // 2 + 3, kernel_size)
@@ -146,10 +146,7 @@ def _gaussian_or_uniform_kernel(self, kernel_size: Sequence[int], sigma: Sequenc
 
         return torch.matmul(kernel_x.t(), kernel_y)  # (kernel_size, 1) * (1, kernel_size)
 
-    @reinit__is_reduced
-    def update(self, output: Sequence[torch.Tensor]) -> None:
-        y_pred, y = output[0].detach(), output[1].detach()
-
+    def _check_type_and_shape(self, y_pred, y):
         if y_pred.dtype != y.dtype:
             raise TypeError(
                 f"Expected y_pred and y to have the same data type. Got y_pred: {y_pred.dtype} and y: {y.dtype}."
@@ -165,6 +162,12 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
                 f"Expected y_pred and y to have BxCxHxW shape. Got y_pred: {y_pred.shape} and y: {y.shape}."
             )
 
+    @reinit__is_reduced
+    def update(self, output: Sequence[torch.Tensor]) -> None:
+        y_pred, y = output[0].detach(), output[1].detach()
+
+        self._check_type_and_shape(y_pred, y)
+
         # converts potential integer tensor to fp
         if not y.is_floating_point():
             y = y.float()
@@ -213,7 +216,15 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
         b2 = sigma_pred_sq + sigma_target_sq + self.c2
 
         ssim_idx = (a1 * a2) / (b1 * b2)
-        self._sum_of_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=self._double_dtype).sum().to(device=self._device)
+
+        # In case when ssim_idx can be MPS tensor and self._device is not MPS
+        # self._double_dtype is float64.
+        # As MPS does not support float64 we should set dtype to float32
+        double_dtype = self._double_dtype
+        if ssim_idx.device.type == "mps" and self._double_dtype == torch.float64:
+            double_dtype = torch.float32
+
+        self._sum_of_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=double_dtype).sum().to(device=self._device)
 
         self._num_examples += y.shape[0]
 

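The comment in the last hunk is the crux of the fix: self._double_dtype defaults to float64, but MPS cannot reduce in float64, so the accumulation dtype is downgraded when the intermediate tensor lives on MPS. Below is a minimal, standalone sketch of that fallback pattern; the helper name accumulation_dtype is illustrative and not part of the commit.

import torch

def accumulation_dtype(t: torch.Tensor, preferred: torch.dtype = torch.float64) -> torch.dtype:
    # MPS has no float64 kernels, so fall back to float32 there;
    # every other backend keeps the preferred double-precision accumulator.
    if t.device.type == "mps" and preferred == torch.float64:
        return torch.float32
    return preferred

# Reduce over the C, H, W dimensions, as update() does for ssim_idx.
ssim_idx = torch.rand(4, 3, 8, 8)  # on Apple silicon this could be an MPS tensor
per_image = torch.mean(ssim_idx, (1, 2, 3), dtype=accumulation_dtype(ssim_idx))
print(per_image.dtype)  # torch.float64 on cpu/cuda, torch.float32 on mps
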
tests/ignite/conftest.py

Lines changed: 15 additions & 9 deletions

@@ -69,19 +69,25 @@ def term_handler():
         yield  # Just pass through if SIGTERM isn't supported or we are not in the main thread
 
 
-@pytest.fixture(
-    params=[
-        "cpu",
-        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no CUDA support")),
-        pytest.param(
-            "mps", marks=pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS support")
-        ),
-    ]
-)
+available_devices_list = [
+    "cpu",
+    pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no CUDA support")),
+    pytest.param(
+        "mps", marks=pytest.mark.skipif(not is_mps_available_and_functional(), reason="Skip if no MPS support")
+    ),
+]
+
+
+@pytest.fixture(params=available_devices_list)
 def available_device(request):
     return request.param
 
 
+@pytest.fixture(params=available_devices_list)
+def available_device2(request):
+    return request.param
+
+
 @pytest.fixture()
 def dirname():
     path = Path(tempfile.mkdtemp())

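Factoring the device list out lets a second fixture reuse the same parametrization. A test that requests both fixtures is collected once per device pair (the cross product), which is what the rewritten test_ssim_device below relies on. A minimal sketch, with an illustrative test name and assertion that are not part of the commit:

import torch

# Requesting both fixtures yields the cross product of available devices,
# e.g. cpu/cpu, cpu/cuda, cuda/cpu, cuda/cuda, plus mps pairs where supported.
def test_every_device_pair(available_device, available_device2):
    x = torch.rand(2, 3, device=available_device2)
    assert x.to(available_device).device.type == available_device
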
tests/ignite/metrics/gan/test_fid.py

Lines changed: 10 additions & 9 deletions

@@ -67,8 +67,9 @@ def test_compute_fid_from_features(available_device):
     mu1, sigma1 = train_samples.mean(axis=0), cov(train_samples, rowvar=False)
     mu2, sigma2 = test_samples.mean(axis=0), cov(test_samples, rowvar=False)
 
+    tol = 1e-4 if available_device == "mps" else 1e-5
     assert (
-        pytest.approx(pytorch_fid_score.calculate_frechet_distance(mu1, sigma1, mu2, sigma2), rel=1e-5)
+        pytest.approx(pytorch_fid_score.calculate_frechet_distance(mu1, sigma1, mu2, sigma2), rel=tol)
         == fid_scorer.compute()
     )
 
@@ -142,22 +143,22 @@ def test_statistics(available_device):
     fid_scorer.update([train_samples[:5], test_samples[:5]])
     fid_scorer.update([train_samples[5:], test_samples[5:]])
 
-    mu1, sigma1 = train_samples.mean(axis=0), torch.tensor(cov(train_samples, rowvar=False))
-    mu2, sigma2 = test_samples.mean(axis=0), torch.tensor(cov(test_samples, rowvar=False))
+    mu1 = train_samples.mean(axis=0, dtype=torch.float64)
+    sigma1 = torch.tensor(cov(train_samples, rowvar=False), dtype=torch.float64)
+    mu2 = test_samples.mean(axis=0, dtype=torch.float64)
+    sigma2 = torch.tensor(cov(test_samples, rowvar=False), dtype=torch.float64)
 
     fid_mu1 = fid_scorer._train_total / fid_scorer._num_examples
     fid_sigma1 = fid_scorer._get_covariance(fid_scorer._train_sigma, fid_scorer._train_total)
 
     fid_mu2 = fid_scorer._test_total / fid_scorer._num_examples
     fid_sigma2 = fid_scorer._get_covariance(fid_scorer._test_sigma, fid_scorer._test_total)
 
-    assert torch.isclose(mu1.double(), fid_mu1.cpu()).all()
-    for cov1, cov2 in zip(sigma1, fid_sigma1):
-        assert torch.isclose(cov1.double(), cov2.cpu(), rtol=1e-04, atol=1e-04).all()
+    assert torch.allclose(mu1, fid_mu1.to(mu1))
+    assert torch.allclose(sigma1, fid_sigma1.to(sigma1), rtol=1e-04, atol=1e-04)
 
-    assert torch.isclose(mu2.double(), fid_mu2.cpu()).all()
-    for cov1, cov2 in zip(sigma2, fid_sigma2):
-        assert torch.isclose(cov1.double(), cov2.cpu(), rtol=1e-04, atol=1e-04).all()
+    assert torch.allclose(mu2, fid_mu2.to(mu2))
+    assert torch.allclose(sigma2, fid_sigma2.to(sigma2), rtol=1e-04, atol=1e-04)
 
 
 def _test_distrib_integration(device):

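The updated assertions compute the reference statistics directly in float64 and compare whole tensors with torch.allclose instead of looping over rows with torch.isclose; Tensor.to(other) matches the reference's dtype and device in one call. A small standalone illustration of the equivalence, with made-up tensors that are not from the commit:

import torch

ref = torch.rand(4, 4, dtype=torch.float64)
approx = (ref + 1e-6 * torch.randn_like(ref)).float()  # e.g. a float32 result from another device

# Old style: compare row by row and reduce each comparison manually.
row_by_row = all(torch.isclose(r, a.double(), rtol=1e-4, atol=1e-4).all() for r, a in zip(ref, approx))

# New style: one call over the whole tensor, after matching dtype (and device) via .to(ref).
whole_tensor = torch.allclose(ref, approx.to(ref), rtol=1e-4, atol=1e-4)

assert row_by_row == whole_tensor
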
tests/ignite/metrics/test_ssim.py

Lines changed: 17 additions & 21 deletions

@@ -128,18 +128,9 @@ def compare_ssim_ignite_skiimg(
     assert np.allclose(ignite_ssim, skimg_ssim, atol=precision)
 
 
-@pytest.mark.parametrize(
-    "metric_device, y_pred_device",
-    [
-        [torch.device("cpu"), torch.device("cpu")],
-        [torch.device("cpu"), torch.device("cuda")],
-        [torch.device("cuda"), torch.device("cpu")],
-        [torch.device("cuda"), torch.device("cuda")],
-    ],
-)
-def test_ssim_device(available_device, metric_device, y_pred_device):
-    if available_device == "cpu":
-        pytest.skip("This test requires a cuda device.")
+def test_ssim_device(available_device, available_device2):
+    metric_device = available_device
+    y_pred_device = available_device2
 
     data_range = 1.0
     sigma = 1.5
@@ -150,26 +141,28 @@ def test_ssim_device(available_device, metric_device, y_pred_device):
     y_pred = torch.rand(shape, device=y_pred_device)
     y = y_pred * 0.8
 
-    if metric_device == torch.device("cuda") and y_pred_device == torch.device("cpu"):
-        with pytest.warns(UserWarning):
+    if metric_device != y_pred_device and y_pred_device == "cpu":
+        with pytest.warns(
+            UserWarning,
+            match=r"y_pred tensor is on cpu device but previous computation was on another device",
+        ):
             ssim.update((y_pred, y))
     else:
         ssim.update((y_pred, y))
 
-    if metric_device == torch.device("cuda") or y_pred_device == torch.device("cuda"):
-        # A tensor will always have the device index set
-        excepted_device = torch.device("cuda:0")
+    if y_pred_device != "cpu" and metric_device == "cpu":
+        excepted_device = y_pred_device
     else:
-        excepted_device = torch.device("cpu")
+        excepted_device = metric_device
 
-    assert ssim._kernel.device == excepted_device
+    assert ssim._kernel.device.type == excepted_device
 
 
 def test_ssim_variable_batchsize(available_device):
     # Checks https://github.com/pytorch/ignite/issues/2532
     sigma = 1.5
     data_range = 1.0
-    ssim = SSIM(data_range=data_range, sigma=sigma)
+    ssim = SSIM(data_range=data_range, sigma=sigma, device=available_device)
 
     y_preds = [
         torch.rand(12, 3, 28, 28, device=available_device),
@@ -209,11 +202,14 @@ def test_ssim_variable_channel(available_device):
 @pytest.mark.parametrize(
     "dtype, precision", [(torch.bfloat16, 2e-3), (torch.float16, 4e-4), (torch.float32, 2e-5), (torch.float64, 2e-5)]
 )
-def test_cuda_ssim_dtypes(available_device, dtype, precision):
+def test_ssim_dtypes(available_device, dtype, precision):
     # Checks https://github.com/pytorch/ignite/pull/3034
     if available_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
         pytest.skip(reason=f"Unsupported dtype {dtype} on CPU device")
 
+    if available_device == "mps" and dtype in [torch.float64]:
+        pytest.skip(reason=f"Unsupported dtype {dtype} on MPS device")
+
     shape = (12, 3, 28, 28)
 
     y_pred = torch.rand(shape, device=available_device, dtype=dtype)

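The rewritten test_ssim_device encodes the kernel-placement rule asserted above: a metric constructed on cpu re-creates its kernel on the device of the incoming y_pred, otherwise the kernel stays on the metric's own device. A hedged reproduction on a CUDA machine, mirroring the test's assertions rather than adding new commit code:

import torch
from ignite.metrics import SSIM

if torch.cuda.is_available():
    ssim = SSIM(data_range=1.0, sigma=1.5, device="cpu")
    y_pred = torch.rand(4, 3, 28, 28, device="cuda")
    ssim.update((y_pred, y_pred * 0.8))
    # Metric on cpu, inputs on cuda: the kernel follows the inputs.
    assert ssim._kernel.device.type == "cuda"
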