From 1f2ecd8c8a6d6aeb36226d306b164c97ba5cee82 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 24 Mar 2021 15:25:39 +0000 Subject: [PATCH 01/18] Add quantized version of nms --- setup.py | 7 +- .../csrc/ops/quantized/cpu/qnms_kernel.cpp | 132 ++++++++++++++++++ 2 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp diff --git a/setup.py b/setup.py index c998118335b..23bbdaab378 100644 --- a/setup.py +++ b/setup.py @@ -138,8 +138,11 @@ def get_extensions(): main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + glob.glob(os.path.join(extensions_dir, 'ops', '*.cpp')) - source_cpu = glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) + glob.glob( - os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp')) + source_cpu = ( + glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) + + glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp')) + + glob.glob(os.path.join(extensions_dir, 'ops', 'quantized', 'cpu', '*.cpp')) + ) is_rocm_pytorch = False if torch.__version__ >= '1.5': diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp new file mode 100644 index 00000000000..791b62b7317 --- /dev/null +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -0,0 +1,132 @@ +#include +#include +#include + +namespace vision { +namespace ops { + +namespace { + +template +at::Tensor qnms_kernel_impl( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + + + TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor"); + TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor"); + TORCH_CHECK( + dets.scalar_type() == scores.scalar_type(), + "dets should have the same type as scores"); + + if (dets.numel() == 0) + return at::_empty_affine_quantized({0}, dets.options()); + + auto x1_t = dets.select(1, 0).contiguous(); + auto y1_t = dets.select(1, 1).contiguous(); + auto x2_t = dets.select(1, 2).contiguous(); + auto y2_t = dets.select(1, 3).contiguous(); + + // TODO: compute areas here, to avoid duplicated computation in the most inner loop + // at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); + + auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); + auto ndets = dets.size(0); + + at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); + at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); + + auto suppressed = suppressed_t.data_ptr(); + auto keep = keep_t.data_ptr(); + auto order = order_t.data_ptr(); + auto x1 = x1_t.data_ptr(); + auto y1 = y1_t.data_ptr(); + auto x2 = x2_t.data_ptr(); + auto y2 = y2_t.data_ptr(); + + const auto dets_scale = dets.q_scale(); + const auto dets_zero_point = dets.q_zero_point(); + + int64_t num_to_keep = 0; + + for (int64_t _i = 0; _i < ndets; _i++) { + auto i = order[_i]; + if (suppressed[i] == 1) + continue; + keep[num_to_keep++] = i; + auto ix1 = at::native::dequantize_val(dets_scale, dets_zero_point, x1[i]); + auto iy1 = at::native::dequantize_val(dets_scale, dets_zero_point, y1[i]); + auto ix2 = at::native::dequantize_val(dets_scale, dets_zero_point, x2[i]); + auto iy2 = at::native::dequantize_val(dets_scale, dets_zero_point, y2[i]); + auto iw = ix2 - ix1; + auto ih = iy2 - iy1; + auto iarea = iw * ih; + + for (int64_t _j = _i + 1; _j < ndets; _j++) { + auto j = order[_j]; + if (suppressed[j] == 1) + continue; + auto jx1 = at::native::dequantize_val(dets_scale, dets_zero_point, x1[j]); + auto jy1 = at::native::dequantize_val(dets_scale, dets_zero_point, y1[j]); + auto jx2 = at::native::dequantize_val(dets_scale, dets_zero_point, x2[j]); + auto jy2 = at::native::dequantize_val(dets_scale, dets_zero_point, y2[j]); + auto jw = jx2 - jx1; + auto jh = jy2 - jy1; + auto jarea = jw * jh; + + auto xx1 = std::max(ix1, jx1); + auto yy1 = std::max(iy1, jy1); + auto xx2 = std::min(ix2, jx2); + auto yy2 = std::min(iy2, jy2); + + auto w = std::max(0.f, xx2 - xx1); + auto h = std::max(0.f, yy2 - yy1); + auto inter = w * h; + auto ovr = inter / (iarea + jarea - inter); + if (ovr > iou_threshold) + suppressed[j] = 1; + } + } + return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); +} + +at::Tensor qnms_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + TORCH_CHECK( + dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK( + dets.size(1) == 4, + "boxes should have 4 elements in dimension 1, got ", + dets.size(1)); + TORCH_CHECK( + scores.dim() == 1, + "scores should be a 1d tensor, got ", + scores.dim(), + "D"); + TORCH_CHECK( + dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)); + + auto result = at::empty({0}); + + AT_DISPATCH_QINT_TYPES(dets.scalar_type(), "qnms_kernel", [&] { + result = qnms_kernel_impl(dets, scores, iou_threshold); + }); + return result; +} + +} // namespace + +TORCH_LIBRARY_IMPL(torchvision, QuantizedCPU, m) { + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(qnms_kernel)); +} + +} // namespace ops +} // namespace vision From b63bac6171b1448e4adec8f73b302965489647b1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 24 Mar 2021 18:07:56 +0000 Subject: [PATCH 02/18] Added tests --- test/test_ops.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 8c938ae0e79..e3c52ad594f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -418,6 +418,27 @@ def test_nms(self): self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(3, 2), 0.5) self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(4), 0.5) + def test_qnms(self): + # Note: we compare qnms vs nms instead of qnms vs reference implementation. + # This is because with the int convertion, the trick used in _create_tensors_with_iou + # doesn't really work (in fact, nms vs reference implem will also fail with ints) + err_msg = 'NMS and QNMS give different results for IoU={}' + for iou in [0.2, 0.5, 0.8]: + boxes, scores = self._create_tensors_with_iou(1000, iou) + scores *= 100 # otherwise most scores would be 0 or 1 after int convertion + + # use integer values and clamp to the uint8 range for a fair comparison + boxes = boxes.to(torch.int).to(torch.float).clamp(0, 255) + scores = scores.to(torch.int).to(torch.float).clamp(0, 255) + + qboxes = torch.quantize_per_tensor(boxes, scale=1, zero_point=0, dtype=torch.quint8) + qscores = torch.quantize_per_tensor(scores, scale=1, zero_point=0, dtype=torch.quint8) + + keep = ops.nms(boxes, scores, iou) + qkeep = ops.nms(qboxes, qscores, iou) + + self.assertTrue(torch.allclose(qkeep, keep), err_msg.format(iou)) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_nms_cuda(self, dtype=torch.float64): tol = 1e-3 if dtype is torch.half else 1e-5 From 41a49274db4fea1817d5493a1205915df488b393 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 24 Mar 2021 18:09:15 +0000 Subject: [PATCH 03/18] Compute areas only once --- .../csrc/ops/quantized/cpu/qnms_kernel.cpp | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index 791b62b7317..eb957cd1bb1 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -13,7 +13,6 @@ at::Tensor qnms_kernel_impl( const at::Tensor& scores, double iou_threshold) { - TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor"); TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor"); TORCH_CHECK( @@ -28,15 +27,14 @@ at::Tensor qnms_kernel_impl( auto x2_t = dets.select(1, 2).contiguous(); auto y2_t = dets.select(1, 3).contiguous(); - // TODO: compute areas here, to avoid duplicated computation in the most inner loop - // at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); - auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); auto ndets = dets.size(0); at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); + at::Tensor areas_t = at::zeros({ndets}, dets.options().dtype(at::kFloat)); + auto suppressed = suppressed_t.data_ptr(); auto keep = keep_t.data_ptr(); auto order = order_t.data_ptr(); @@ -44,10 +42,16 @@ at::Tensor qnms_kernel_impl( auto y1 = y1_t.data_ptr(); auto x2 = x2_t.data_ptr(); auto y2 = y2_t.data_ptr(); + auto areas = areas_t.data_ptr(); + auto areas_a = areas_t.accessor(); const auto dets_scale = dets.q_scale(); const auto dets_zero_point = dets.q_zero_point(); + for (int64_t i = 0; i < ndets; i++) { + areas_a[i] = dets_scale**2 * (x2[i].val_ - x1[i].val_) * (y2[i].val_ - y1[i].val_); + } + int64_t num_to_keep = 0; for (int64_t _i = 0; _i < ndets; _i++) { @@ -59,9 +63,7 @@ at::Tensor qnms_kernel_impl( auto iy1 = at::native::dequantize_val(dets_scale, dets_zero_point, y1[i]); auto ix2 = at::native::dequantize_val(dets_scale, dets_zero_point, x2[i]); auto iy2 = at::native::dequantize_val(dets_scale, dets_zero_point, y2[i]); - auto iw = ix2 - ix1; - auto ih = iy2 - iy1; - auto iarea = iw * ih; + auto iarea = areas[i]; for (int64_t _j = _i + 1; _j < ndets; _j++) { auto j = order[_j]; @@ -71,10 +73,6 @@ at::Tensor qnms_kernel_impl( auto jy1 = at::native::dequantize_val(dets_scale, dets_zero_point, y1[j]); auto jx2 = at::native::dequantize_val(dets_scale, dets_zero_point, x2[j]); auto jy2 = at::native::dequantize_val(dets_scale, dets_zero_point, y2[j]); - auto jw = jx2 - jx1; - auto jh = jy2 - jy1; - auto jarea = jw * jh; - auto xx1 = std::max(ix1, jx1); auto yy1 = std::max(iy1, jy1); auto xx2 = std::min(ix2, jx2); @@ -83,7 +81,7 @@ at::Tensor qnms_kernel_impl( auto w = std::max(0.f, xx2 - xx1); auto h = std::max(0.f, yy2 - yy1); auto inter = w * h; - auto ovr = inter / (iarea + jarea - inter); + auto ovr = inter / (iarea + areas[j] - inter); if (ovr > iou_threshold) suppressed[j] = 1; } From cdb44fc2263806f892796c186389b8cf9c4433ad Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 24 Mar 2021 18:42:34 +0000 Subject: [PATCH 04/18] remove calls to dequantize_val --- .../csrc/ops/quantized/cpu/qnms_kernel.cpp | 42 ++++++++----------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index eb957cd1bb1..c313024090a 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -22,17 +22,17 @@ at::Tensor qnms_kernel_impl( if (dets.numel() == 0) return at::_empty_affine_quantized({0}, dets.options()); + const auto ndets = dets.size(0); + const double dets_scale = dets.q_scale(); + + auto x1_t = dets.select(1, 0).contiguous(); auto y1_t = dets.select(1, 1).contiguous(); auto x2_t = dets.select(1, 2).contiguous(); auto y2_t = dets.select(1, 3).contiguous(); - auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); - auto ndets = dets.size(0); - at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); - at::Tensor areas_t = at::zeros({ndets}, dets.options().dtype(at::kFloat)); auto suppressed = suppressed_t.data_ptr(); @@ -43,13 +43,10 @@ at::Tensor qnms_kernel_impl( auto x2 = x2_t.data_ptr(); auto y2 = y2_t.data_ptr(); auto areas = areas_t.data_ptr(); - auto areas_a = areas_t.accessor(); - - const auto dets_scale = dets.q_scale(); - const auto dets_zero_point = dets.q_zero_point(); + auto areas_a = areas_t.accessor(); for (int64_t i = 0; i < ndets; i++) { - areas_a[i] = dets_scale**2 * (x2[i].val_ - x1[i].val_) * (y2[i].val_ - y1[i].val_); + areas_a[i] = dets_scale * dets_scale * (x2[i].val_ - x1[i].val_) * (y2[i].val_ - y1[i].val_); } int64_t num_to_keep = 0; @@ -59,27 +56,24 @@ at::Tensor qnms_kernel_impl( if (suppressed[i] == 1) continue; keep[num_to_keep++] = i; - auto ix1 = at::native::dequantize_val(dets_scale, dets_zero_point, x1[i]); - auto iy1 = at::native::dequantize_val(dets_scale, dets_zero_point, y1[i]); - auto ix2 = at::native::dequantize_val(dets_scale, dets_zero_point, x2[i]); - auto iy2 = at::native::dequantize_val(dets_scale, dets_zero_point, y2[i]); + + auto ix1val = x1[i].val_; + auto iy1val = y1[i].val_; + auto ix2val = x2[i].val_; + auto iy2val = y2[i].val_; auto iarea = areas[i]; for (int64_t _j = _i + 1; _j < ndets; _j++) { auto j = order[_j]; if (suppressed[j] == 1) continue; - auto jx1 = at::native::dequantize_val(dets_scale, dets_zero_point, x1[j]); - auto jy1 = at::native::dequantize_val(dets_scale, dets_zero_point, y1[j]); - auto jx2 = at::native::dequantize_val(dets_scale, dets_zero_point, x2[j]); - auto jy2 = at::native::dequantize_val(dets_scale, dets_zero_point, y2[j]); - auto xx1 = std::max(ix1, jx1); - auto yy1 = std::max(iy1, jy1); - auto xx2 = std::min(ix2, jx2); - auto yy2 = std::min(iy2, jy2); - - auto w = std::max(0.f, xx2 - xx1); - auto h = std::max(0.f, yy2 - yy1); + auto xx1 = std::max(ix1val, x1[j].val_); + auto yy1 = std::max(iy1val, y1[j].val_); + auto xx2 = std::min(ix2val, x2[j].val_); + auto yy2 = std::min(iy2val, y2[j].val_); + + float w = dets_scale * std::max(0, xx2 - xx1); + float h = dets_scale * std::max(0, yy2 - yy1); auto inter = w * h; auto ovr = inter / (iarea + areas[j] - inter); if (ovr > iou_threshold) From fb451cc2ab487ef21c0a08dafab34f86d3894881 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 24 Mar 2021 18:48:11 +0000 Subject: [PATCH 05/18] fix return type for empty tensor --- torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index c313024090a..6029e448fcb 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -20,7 +20,7 @@ at::Tensor qnms_kernel_impl( "dets should have the same type as scores"); if (dets.numel() == 0) - return at::_empty_affine_quantized({0}, dets.options()); + return at::empty({0}, dets.options().dtype(at::kLong)); const auto ndets = dets.size(0); const double dets_scale = dets.q_scale(); From 2eef613e7fffbbf0ac63837ffbfa0de6cb4ebf24 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 24 Mar 2021 18:53:48 +0000 Subject: [PATCH 06/18] flake8 --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index e3c52ad594f..a20e91d96e9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -425,7 +425,7 @@ def test_qnms(self): err_msg = 'NMS and QNMS give different results for IoU={}' for iou in [0.2, 0.5, 0.8]: boxes, scores = self._create_tensors_with_iou(1000, iou) - scores *= 100 # otherwise most scores would be 0 or 1 after int convertion + scores *= 100 # otherwise most scores would be 0 or 1 after int convertion # use integer values and clamp to the uint8 range for a fair comparison boxes = boxes.to(torch.int).to(torch.float).clamp(0, 255) From f8a56d7ba3cee2bdfa2c26f33388be298e4f3ea3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 Mar 2021 10:37:52 +0000 Subject: [PATCH 07/18] remove use of scale as it gets cancelled out --- torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index 6029e448fcb..96136c53b69 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -23,8 +23,6 @@ at::Tensor qnms_kernel_impl( return at::empty({0}, dets.options().dtype(at::kLong)); const auto ndets = dets.size(0); - const double dets_scale = dets.q_scale(); - auto x1_t = dets.select(1, 0).contiguous(); auto y1_t = dets.select(1, 1).contiguous(); @@ -46,7 +44,10 @@ at::Tensor qnms_kernel_impl( auto areas_a = areas_t.accessor(); for (int64_t i = 0; i < ndets; i++) { - areas_a[i] = dets_scale * dets_scale * (x2[i].val_ - x1[i].val_) * (y2[i].val_ - y1[i].val_); + // Note: To get the exact area we'd need to multiply by scale**2, but this + // would get canceled out in the computation of ovr below. + // So we leave that out. + areas_a[i] = (x2[i].val_ - x1[i].val_) * (y2[i].val_ - y1[i].val_); } int64_t num_to_keep = 0; @@ -72,8 +73,8 @@ at::Tensor qnms_kernel_impl( auto xx2 = std::min(ix2val, x2[j].val_); auto yy2 = std::min(iy2val, y2[j].val_); - float w = dets_scale * std::max(0, xx2 - xx1); - float h = dets_scale * std::max(0, yy2 - yy1); + auto w = std::max(0, xx2 - xx1); // * scale (gets canceled below) + auto h = std::max(0, yy2 - yy1); // * scale (gets canceled below) auto inter = w * h; auto ovr = inter / (iarea + areas[j] - inter); if (ovr > iou_threshold) From 4c75c568e1fbe9a91d2c8be6149c071a52074a8b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 Mar 2021 10:38:01 +0000 Subject: [PATCH 08/18] simpler int convertion in tests --- test/test_ops.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a20e91d96e9..1f8acbea8dd 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -427,13 +427,12 @@ def test_qnms(self): boxes, scores = self._create_tensors_with_iou(1000, iou) scores *= 100 # otherwise most scores would be 0 or 1 after int convertion - # use integer values and clamp to the uint8 range for a fair comparison - boxes = boxes.to(torch.int).to(torch.float).clamp(0, 255) - scores = scores.to(torch.int).to(torch.float).clamp(0, 255) - qboxes = torch.quantize_per_tensor(boxes, scale=1, zero_point=0, dtype=torch.quint8) qscores = torch.quantize_per_tensor(scores, scale=1, zero_point=0, dtype=torch.quint8) + boxes = qboxes.dequantize() + scores = qscores.dequantize() + keep = ops.nms(boxes, scores, iou) qkeep = ops.nms(qboxes, qscores, iou) From 2ac53132edbbfc09bf1d2bf2ecc31eb1c4a76d2e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 Mar 2021 10:41:19 +0000 Subject: [PATCH 09/18] explicitly set ovr to double --- torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index 96136c53b69..f428793a82c 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -76,7 +76,7 @@ at::Tensor qnms_kernel_impl( auto w = std::max(0, xx2 - xx1); // * scale (gets canceled below) auto h = std::max(0, yy2 - yy1); // * scale (gets canceled below) auto inter = w * h; - auto ovr = inter / (iarea + areas[j] - inter); + double ovr = inter / (iarea + areas[j] - inter); if (ovr > iou_threshold) suppressed[j] = 1; } From f47d238ba0d1b571da694f657f04a9cb57eb101d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 Mar 2021 10:45:14 +0000 Subject: [PATCH 10/18] add tests for more values of scale and zero_point --- test/test_ops.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 1f8acbea8dd..0031da45cce 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -424,19 +424,22 @@ def test_qnms(self): # doesn't really work (in fact, nms vs reference implem will also fail with ints) err_msg = 'NMS and QNMS give different results for IoU={}' for iou in [0.2, 0.5, 0.8]: - boxes, scores = self._create_tensors_with_iou(1000, iou) - scores *= 100 # otherwise most scores would be 0 or 1 after int convertion + for scale, zero_point in ((1, 0), (2, 50), (3, 10)): + boxes, scores = self._create_tensors_with_iou(1000, iou) + scores *= 100 # otherwise most scores would be 0 or 1 after int convertion - qboxes = torch.quantize_per_tensor(boxes, scale=1, zero_point=0, dtype=torch.quint8) - qscores = torch.quantize_per_tensor(scores, scale=1, zero_point=0, dtype=torch.quint8) + qboxes = torch.quantize_per_tensor(boxes, scale=scale, zero_point=zero_point, + dtype=torch.quint8) + qscores = torch.quantize_per_tensor(scores, scale=scale, zero_point=zero_point, + dtype=torch.quint8) - boxes = qboxes.dequantize() - scores = qscores.dequantize() + boxes = qboxes.dequantize() + scores = qscores.dequantize() - keep = ops.nms(boxes, scores, iou) - qkeep = ops.nms(qboxes, qscores, iou) + keep = ops.nms(boxes, scores, iou) + qkeep = ops.nms(qboxes, qscores, iou) - self.assertTrue(torch.allclose(qkeep, keep), err_msg.format(iou)) + self.assertTrue(torch.allclose(qkeep, keep), err_msg.format(iou)) @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_nms_cuda(self, dtype=torch.float64): From 4b31259cf1e3d81caf242f54e47d93386b5c77cc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 Mar 2021 14:39:32 +0000 Subject: [PATCH 11/18] comment about underflow --- torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index f428793a82c..7450a586101 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -44,9 +44,11 @@ at::Tensor qnms_kernel_impl( auto areas_a = areas_t.accessor(); for (int64_t i = 0; i < ndets; i++) { - // Note: To get the exact area we'd need to multiply by scale**2, but this - // would get canceled out in the computation of ovr below. - // So we leave that out. + // Note 1: To get the exact area we'd need to multiply by scale**2, but this + // would get canceled out in the computation of ovr below. So we leave that + // out. + // Note 2: degenerate boxes (x2 < x1 or y2 < y1) may underflow. Same below + // when computing w and h areas_a[i] = (x2[i].val_ - x1[i].val_) * (y2[i].val_ - y1[i].val_); } From 7c18714ffafc777b7ec5d9cf508555717c7c103a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 Mar 2021 14:55:48 +0000 Subject: [PATCH 12/18] remove unnecessary accessor --- torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index 7450a586101..d1762c81470 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -42,14 +42,13 @@ at::Tensor qnms_kernel_impl( auto y2 = y2_t.data_ptr(); auto areas = areas_t.data_ptr(); - auto areas_a = areas_t.accessor(); for (int64_t i = 0; i < ndets; i++) { // Note 1: To get the exact area we'd need to multiply by scale**2, but this // would get canceled out in the computation of ovr below. So we leave that // out. // Note 2: degenerate boxes (x2 < x1 or y2 < y1) may underflow. Same below // when computing w and h - areas_a[i] = (x2[i].val_ - x1[i].val_) * (y2[i].val_ - y1[i].val_); + areas[i] = (x2[i].val_ - x1[i].val_) * (y2[i].val_ - y1[i].val_); } int64_t num_to_keep = 0; From 31d76ce22bcecf95e7040e44e6e14d3cfbd4b9c5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 25 Mar 2021 16:04:29 +0000 Subject: [PATCH 13/18] properly convert to float for division --- torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index d1762c81470..4117f880bd8 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -77,7 +77,7 @@ at::Tensor qnms_kernel_impl( auto w = std::max(0, xx2 - xx1); // * scale (gets canceled below) auto h = std::max(0, yy2 - yy1); // * scale (gets canceled below) auto inter = w * h; - double ovr = inter / (iarea + areas[j] - inter); + auto ovr = (float)inter / (iarea + areas[j] - inter); if (ovr > iou_threshold) suppressed[j] = 1; } From 618bbe1bd1a46b9e57b9f61aae78a76121471112 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 29 Mar 2021 14:32:22 +0100 Subject: [PATCH 14/18] Add comments about underflow --- torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index 4117f880bd8..76a9fb1a559 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -46,8 +46,10 @@ at::Tensor qnms_kernel_impl( // Note 1: To get the exact area we'd need to multiply by scale**2, but this // would get canceled out in the computation of ovr below. So we leave that // out. - // Note 2: degenerate boxes (x2 < x1 or y2 < y1) may underflow. Same below - // when computing w and h + // Note 2: degenerate boxes (x2 < x1 or y2 < y1) may underflow, although + // integral promotion rules will likely prevent it (see + // https://stackoverflow.com/questions/32959564/subtraction-of-two-unsigned-gives-signed + // for more details). areas[i] = (x2[i].val_ - x1[i].val_) * (y2[i].val_ - y1[i].val_); } @@ -74,6 +76,11 @@ at::Tensor qnms_kernel_impl( auto xx2 = std::min(ix2val, x2[j].val_); auto yy2 = std::min(iy2val, y2[j].val_); + // This may underflow if xx2 < xx1 on unsigned types but as noted above, + // integral promotion should prevent it. Also, an actual underflow would + // lead to a negative ovr (because of high value for inter), but since the + // actual over should have been 0 the condition below isn't altered, and + // thus the underflow should be effectively harmless. auto w = std::max(0, xx2 - xx1); // * scale (gets canceled below) auto h = std::max(0, yy2 - yy1); // * scale (gets canceled below) auto inter = w * h; From 7e2733751c3b642eb37b77b4cb1a243b8720e060 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 29 Mar 2021 14:53:08 +0100 Subject: [PATCH 15/18] explicitely cast coordinates to float to allow vectorization --- .../csrc/ops/quantized/cpu/qnms_kernel.cpp | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index 76a9fb1a559..03427c9191b 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -61,30 +61,26 @@ at::Tensor qnms_kernel_impl( continue; keep[num_to_keep++] = i; - auto ix1val = x1[i].val_; - auto iy1val = y1[i].val_; - auto ix2val = x2[i].val_; - auto iy2val = y2[i].val_; - auto iarea = areas[i]; + // We explicitely cast coordinates to float so that the code can be vectorized. + float ix1val = x1[i].val_; + float iy1val = y1[i].val_; + float ix2val = x2[i].val_; + float iy2val = y2[i].val_; + float iarea = areas[i]; for (int64_t _j = _i + 1; _j < ndets; _j++) { auto j = order[_j]; if (suppressed[j] == 1) continue; - auto xx1 = std::max(ix1val, x1[j].val_); - auto yy1 = std::max(iy1val, y1[j].val_); - auto xx2 = std::min(ix2val, x2[j].val_); - auto yy2 = std::min(iy2val, y2[j].val_); - - // This may underflow if xx2 < xx1 on unsigned types but as noted above, - // integral promotion should prevent it. Also, an actual underflow would - // lead to a negative ovr (because of high value for inter), but since the - // actual over should have been 0 the condition below isn't altered, and - // thus the underflow should be effectively harmless. - auto w = std::max(0, xx2 - xx1); // * scale (gets canceled below) - auto h = std::max(0, yy2 - yy1); // * scale (gets canceled below) + float xx1 = std::max(ix1val, (float)x1[j].val_); + float yy1 = std::max(iy1val, (float)y1[j].val_); + float xx2 = std::min(ix2val, (float)x2[j].val_); + float yy2 = std::min(iy2val, (float)y2[j].val_); + + auto w = std::max(0.f, xx2 - xx1); // * scale (gets canceled below) + auto h = std::max(0.f, yy2 - yy1); // * scale (gets canceled below) auto inter = w * h; - auto ovr = (float)inter / (iarea + areas[j] - inter); + auto ovr = inter / (iarea + areas[j] - inter); if (ovr > iou_threshold) suppressed[j] = 1; } From 4cda6e56f74cb824f99bfe6730f08eeafdf2b69d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 29 Mar 2021 15:28:02 +0100 Subject: [PATCH 16/18] clang --- .../csrc/ops/quantized/cpu/qnms_kernel.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index 03427c9191b..7dd75e9c35b 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -61,7 +61,8 @@ at::Tensor qnms_kernel_impl( continue; keep[num_to_keep++] = i; - // We explicitely cast coordinates to float so that the code can be vectorized. + // We explicitely cast coordinates to float so that the code can be + // vectorized. float ix1val = x1[i].val_; float iy1val = y1[i].val_; float ix2val = x2[i].val_; @@ -77,8 +78,8 @@ at::Tensor qnms_kernel_impl( float xx2 = std::min(ix2val, (float)x2[j].val_); float yy2 = std::min(iy2val, (float)y2[j].val_); - auto w = std::max(0.f, xx2 - xx1); // * scale (gets canceled below) - auto h = std::max(0.f, yy2 - yy1); // * scale (gets canceled below) + auto w = std::max(0.f, xx2 - xx1); // * scale (gets canceled below) + auto h = std::max(0.f, yy2 - yy1); // * scale (gets canceled below) auto inter = w * h; auto ovr = inter / (iarea + areas[j] - inter); if (ovr > iou_threshold) @@ -106,10 +107,10 @@ at::Tensor qnms_kernel( TORCH_CHECK( dets.size(0) == scores.size(0), "boxes and scores should have same number of elements in ", - "dimension 0, got ", - dets.size(0), - " and ", - scores.size(0)); + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)); auto result = at::empty({0}); From dd463768cd09ff363eb3e9b9598f1f4cc6ad2513 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 29 Mar 2021 15:31:02 +0100 Subject: [PATCH 17/18] clang again --- torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index 7dd75e9c35b..1364d823fdd 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -20,7 +20,7 @@ at::Tensor qnms_kernel_impl( "dets should have the same type as scores"); if (dets.numel() == 0) - return at::empty({0}, dets.options().dtype(at::kLong)); + return at::empty({0}, dets.options().dtype(at::kLong)); const auto ndets = dets.size(0); From aef04fdb5dddb0653c4ff41e97746cdc93fd4846 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 29 Mar 2021 15:44:48 +0100 Subject: [PATCH 18/18] hopefully OK now --- torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp index 1364d823fdd..f7b081327b2 100644 --- a/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp +++ b/torchvision/csrc/ops/quantized/cpu/qnms_kernel.cpp @@ -12,7 +12,6 @@ at::Tensor qnms_kernel_impl( const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { - TORCH_CHECK(!dets.is_cuda(), "dets must be a CPU tensor"); TORCH_CHECK(!scores.is_cuda(), "scores must be a CPU tensor"); TORCH_CHECK(