Skip to content

Commit 0444933

Browse files
Resolve MPS's lack of cummax
1 parent 3658f95 commit 0444933

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

ignite/metrics/vision/object_detection_average_precision_recall.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
23

34
import torch
@@ -125,7 +126,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo
125126
class_mean=None,
126127
)
127128
precision = torch.double if torch.device(device).type != "mps" else torch.float32
128-
self.rec_thresholds = self.rec_thresholds.to(device=device, dtype=precision)
129+
self.rec_thresholds = cast(torch.Tensor, self.rec_thresholds).to(device=device, dtype=precision)
129130

130131
@reinit__is_reduced
131132
def reset(self) -> None:
@@ -234,7 +235,10 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens
234235
Returns:
235236
average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions.
236237
"""
238+
mps_cpu_fallback = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK", "0")
239+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
237240
precision_integrand = precision.flip(-1).cummax(dim=-1).values.flip(-1)
241+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = mps_cpu_fallback
238242
rec_thresholds = cast(torch.Tensor, self.rec_thresholds).repeat((*recall.shape[:-1], 1))
239243
rec_thresh_indices = (
240244
torch.searchsorted(recall, rec_thresholds)

0 commit comments

Comments
 (0)