Skip to content

Commit c786d12

Browse files
authored
Setting seeds for test_batched_nms_implementations (#4766)
* adding multiweight support for deeplabv3 prototype models * adding default values for optional params * fixing bug * addressing PR comment * fixing seed in test_batched_nms_implementations * change seeds in test_batched_nms_implementations * change seeds in test_batched_nms_implementations
1 parent 9528d96 commit c786d12

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

test/test_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,10 @@ def test_nms_cuda_float16(self):
557557
keep16 = ops.nms(boxes.to(torch.float16), scores.to(torch.float16), iou_thres)
558558
assert_equal(keep32, keep16)
559559

560-
def test_batched_nms_implementations(self):
560+
@pytest.mark.parametrize("seed", range(10))
561+
def test_batched_nms_implementations(self, seed):
561562
"""Make sure that both implementations of batched_nms yield identical results"""
563+
torch.random.manual_seed(seed)
562564

563565
num_boxes = 1000
564566
iou_threshold = 0.9

0 commit comments

Comments
 (0)