Skip to content

Commit c78fa53

Browse files
committed
Use float32 in metrics when metric device is MPS
Related to #3326
1 parent 1c0818f commit c78fa53

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

ignite/metrics/multilabel_confusion_matrix.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,11 @@ def compute(self) -> torch.Tensor:
136136
raise NotComputableError("Confusion matrix must have at least one example before it can be computed.")
137137

138138
if self.normalized:
139-
conf = self.confusion_matrix.to(dtype=torch.float64)
139+
# MPS framework doesn't support float64, should use float32
140+
double_dtype = torch.float64
141+
if self.confusion_matrix.device.type == "mps":
142+
double_dtype = torch.float32
143+
conf = self.confusion_matrix.to(dtype=double_dtype)
140144
sums = conf.sum(dim=(1, 2))
141145
return conf / sums[:, None, None]
142146

ignite/metrics/precision.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def __init__(
3838
super(_BasePrecisionRecall, self).__init__(
3939
output_transform=output_transform, is_multilabel=is_multilabel, device=device, skip_unrolling=skip_unrolling
4040
)
41+
# MPS framework doesn't support float64, should use float32
42+
self._double_dtype = torch.float64
43+
if self._device.type == "mps":
44+
self._double_dtype = torch.float32
4145

4246
def _check_type(self, output: Sequence[torch.Tensor]) -> None:
4347
super()._check_type(output)
@@ -81,8 +85,8 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
8185
y = torch.transpose(y, 1, -1).reshape(-1, num_labels)
8286

8387
# Convert from int cuda/cpu to double on self._device
84-
y_pred = y_pred.to(dtype=torch.float64, device=self._device)
85-
y = y.to(dtype=torch.float64, device=self._device)
88+
y_pred = y_pred.to(dtype=self._double_dtype, device=self._device)
89+
y = y.to(dtype=self._double_dtype, device=self._device)
8690
correct = y * y_pred
8791

8892
return y_pred, y, correct

0 commit comments

Comments
 (0)