We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9528d96 commit c786d12Copy full SHA for c786d12
test/test_ops.py
@@ -557,8 +557,10 @@ def test_nms_cuda_float16(self):
557
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
558
assert_equal(keep32, keep16)
559
560
- def test_batched_nms_implementations(self):
+ @pytest.mark.parametrize("seed", range(10))
561
+ def test_batched_nms_implementations(self, seed):
562
"""Make sure that both implementations of batched_nms yield identical results"""
563
+ torch.random.manual_seed(seed)
564
565
num_boxes = 1000
566
iou_threshold = 0.9
0 commit comments