@@ -38,6 +38,10 @@ def __init__(
         super(_BasePrecisionRecall, self).__init__(
             output_transform=output_transform, is_multilabel=is_multilabel, device=device, skip_unrolling=skip_unrolling
         )
+        # MPS framework doesn't support float64, should use float32
+        self._double_dtype = torch.float64
+        if self._device.type == "mps":
+            self._double_dtype = torch.float32
 
     def _check_type(self, output: Sequence[torch.Tensor]) -> None:
         super()._check_type(output)
@@ -81,8 +85,8 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
             y = torch.transpose(y, 1, -1).reshape(-1, num_labels)
 
         # Convert from int cuda/cpu to double on self._device
-        y_pred = y_pred.to(dtype=torch.float64, device=self._device)
-        y = y.to(dtype=torch.float64, device=self._device)
+        y_pred = y_pred.to(dtype=self._double_dtype, device=self._device)
+        y = y.to(dtype=self._double_dtype, device=self._device)
         correct = y * y_pred
 
         return y_pred, y, correct
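
For context, a minimal standalone sketch (not part of the patch) of the device-conditional dtype selection this diff introduces: PyTorch's MPS backend does not support float64, so double-precision accumulation has to fall back to float32 on that device. The helper name pick_double_dtype is hypothetical and only illustrates the pattern.

    import torch

    def pick_double_dtype(device: torch.device) -> torch.dtype:
        # MPS tensors cannot be float64; use float32 there, float64 elsewhere.
        return torch.float32 if device.type == "mps" else torch.float64

    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    dtype = pick_double_dtype(device)

    y_pred = torch.tensor([1, 0, 1]).to(dtype=dtype, device=device)
    y = torch.tensor([1, 1, 0]).to(dtype=dtype, device=device)
    correct = y * y_pred  # elementwise product, as in _prepare_output

Storing the chosen dtype once in __init__ (self._double_dtype) keeps _prepare_output free of per-call device checks.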