
Commit ad9cc62

Add Quantized version of RoIAlign (#3624)
* WIP
* clang
* docs
* extracted out common utils
* Use better quantization function and pass tensors as parameters
* proper dequantization
* Some tests
* Dequantization optimization, seems to gain a few ms
* clang-format
* again
* more correct test. Had to remove optimization although it almost works
* Also test aligned=True
* remove useless part
* more docs and comments
* Put back optimization with more robust test
* Added check for index upper bound
* avoid possible overflow
* Move common function into common.h
* oops
* scale=1, zero_point=0 makes more sense
* Force batch size of 1 to prevent any indexing bug
* format
* format again
* updated docstring
* put back description comment for pre_calc_bilinear_interpolate
* revert most changes to docstring as it's taken care of in another PR
1 parent 3a278d7 commit ad9cc62
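The new quantized kernel is reached through the existing torchvision.ops.roi_align entry point: passing quantized tensors dispatches to it, as the test added in test/test_ops.py below exercises. A minimal usage sketch, with illustrative shapes and quantization parameters (not taken from the commit):

import torch
from torchvision import ops

x = torch.rand(1, 3, 10, 10)  # a single image: the quantized op only allows batch size 1
rois = torch.tensor([[0.0, 1.0, 1.0, 5.0, 5.0]])  # (K, 5) rows: batch index, x1, y1, x2, y2
qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=50, dtype=torch.quint8)
qrois = torch.quantize_per_tensor(rois, scale=0.1, zero_point=50, dtype=torch.quint8)

qy = ops.roi_align(qx, qrois, output_size=5, spatial_scale=1, sampling_ratio=-1, aligned=True)
y = qy.dequantize()  # back to float; expected to be close to roi_align on the dequantized inputs

The single-image restriction and the closeness to the float result are both asserted in the test below.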

File tree

5 files changed (+417 / -117 lines)

test/test_ops.py

Lines changed: 72 additions & 0 deletions
@@ -299,6 +299,78 @@ def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwa
         for aligned in (True, False):
             super()._test_forward(device, contiguous, x_dtype, rois_dtype, aligned=aligned)
 
+    def test_qroialign(self):
+        """Make sure quantized version of RoIAlign is close to float version"""
+        pool_size = 5
+        img_size = 10
+        n_channels = 2
+        num_imgs = 1
+        dtype = torch.float
+
+        def make_rois(num_rois=1000):
+            rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
+            rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,))  # set batch index
+            rois[:, 3:] += rois[:, 1:3]  # make sure boxes aren't degenerate
+            return rois
+
+        for aligned in (True, False):
+            for scale, zero_point in ((1, 0), (2, 10), (0.1, 50)):
+                for qdtype in (torch.qint8, torch.quint8, torch.qint32):
+
+                    x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
+                    qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qdtype)
+
+                    rois = make_rois()
+                    qrois = torch.quantize_per_tensor(rois, scale=scale, zero_point=zero_point, dtype=qdtype)
+
+                    x, rois = qx.dequantize(), qrois.dequantize()  # we want to pass the same inputs
+
+                    y = ops.roi_align(
+                        x,
+                        rois,
+                        output_size=pool_size,
+                        spatial_scale=1,
+                        sampling_ratio=-1,
+                        aligned=aligned,
+                    )
+                    qy = ops.roi_align(
+                        qx,
+                        qrois,
+                        output_size=pool_size,
+                        spatial_scale=1,
+                        sampling_ratio=-1,
+                        aligned=aligned,
+                    )
+
+                    # The output qy is itself a quantized tensor and there might have been a loss of info
+                    # when it was quantized. For a fair comparison we need to quantize y as well.
+                    quantized_float_y = torch.quantize_per_tensor(y, scale=scale, zero_point=zero_point, dtype=qdtype)
+
+                    try:
+                        # Ideally, we would assert this, which passes with (scale, zero) == (1, 0)
+                        self.assertTrue((qy == quantized_float_y).all())
+                    except AssertionError:
+                        # But because the computations aren't exactly the same between the 2 RoIAlign
+                        # procedures, some rounding error may lead to a difference of 2 in the output.
+                        # For example with (scale, zero) = (2, 10), 45.00000... will be quantized to 44
+                        # but 45.00000001 will be rounded to 46. We make sure below that:
+                        # - such discrepancies between qy and quantized_float_y are very rare (less than 5%)
+                        # - any difference between qy and quantized_float_y is == scale
+                        diff_idx = torch.where(qy != quantized_float_y)
+                        num_diff = diff_idx[0].numel()
+                        self.assertTrue(num_diff / qy.numel() < .05)
+
+                        abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize())
+                        t_scale = torch.full_like(abs_diff, fill_value=scale)
+                        self.assertTrue(torch.allclose(abs_diff, t_scale, atol=1e-5))
+
+        x = torch.randint(50, 100, size=(2, 3, 10, 10)).to(dtype)
+        qx = torch.quantize_per_tensor(x, scale=1, zero_point=0, dtype=torch.qint8)
+        rois = make_rois(10)
+        qrois = torch.quantize_per_tensor(rois, scale=1, zero_point=0, dtype=torch.qint8)
+        with self.assertRaisesRegex(RuntimeError, "Only one image per batch is allowed"):
+            ops.roi_align(qx, qrois, output_size=pool_size)
+
 
 class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
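The except branch above tolerates rare one-level disagreements between the two quantized outputs. As a small illustrative sketch (values chosen for this note, not taken from the test), two almost-equal floats can land one quantization level apart with (scale, zero_point) = (2, 10), which is exactly the difference of scale that the test bounds:

import torch

scale, zero_point = 2, 10
vals = torch.tensor([44.9999, 45.0001])  # nearly equal, but straddling a rounding boundary
q = torch.quantize_per_tensor(vals, scale=scale, zero_point=zero_point, dtype=torch.quint8)
print(q.int_repr())    # tensor([32, 33], dtype=torch.uint8)
print(q.dequantize())  # tensor([44., 46.]) -- dequantized values differ by exactly `scale`

Since the float and quantized RoIAlign kernels do not perform exactly the same floating-point computations, such boundary cases are expected to occur occasionally, which is why the test only requires them to be rare and no larger than one quantization step.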

torchvision/csrc/ops/cpu/roi_align_common.h

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace vision {
+namespace ops {
+namespace detail {
+
+template <typename T>
+struct PreCalc {
+  int pos1;
+  int pos2;
+  int pos3;
+  int pos4;
+  T w1;
+  T w2;
+  T w3;
+  T w4;
+};
+
+// This helper computes the interpolation weights (w1, w2...) for every sampling
+// point of a given box. There are pool_height * pool_width * roi_bin_grid_h *
+// roi_bin_grid_w such sampling points.
+//
+// The weights (w1, w2...) are computed as the areas in this figure:
+// https://en.wikipedia.org/wiki/Bilinear_interpolation#/media/File:Bilinear_interpolation_visualisation.svg
+// and pos1, pos2 etc correspond to the indices of their respective pixels.
+//
+// Note: the weights and indices are shared across all channels, which is why
+// they are pre-calculated prior to the main loop in the RoIAlign kernel.
+// implementation taken from Caffe2
+template <typename T>
+void pre_calc_for_bilinear_interpolate(
+    int height,
+    int width,
+    int pooled_height,
+    int pooled_width,
+    T roi_start_h,
+    T roi_start_w,
+    T bin_size_h,
+    T bin_size_w,
+    int roi_bin_grid_h,
+    int roi_bin_grid_w,
+    std::vector<PreCalc<T>>& pre_calc) {
+  int pre_calc_index = 0;
+  for (int ph = 0; ph < pooled_height; ph++) {
+    for (int pw = 0; pw < pooled_width; pw++) {
+      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
+        const T yy = roi_start_h + ph * bin_size_h +
+            static_cast<T>(iy + .5f) * bin_size_h /
+                static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
+        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
+          const T xx = roi_start_w + pw * bin_size_w +
+              static_cast<T>(ix + .5f) * bin_size_w /
+                  static_cast<T>(roi_bin_grid_w);
+
+          T x = xx;
+          T y = yy;
+          // deal with: inverse elements are out of feature map boundary
+          if (y < -1.0 || y > height || x < -1.0 || x > width) {
+            // empty
+            PreCalc<T> pc;
+            pc.pos1 = 0;
+            pc.pos2 = 0;
+            pc.pos3 = 0;
+            pc.pos4 = 0;
+            pc.w1 = 0;
+            pc.w2 = 0;
+            pc.w3 = 0;
+            pc.w4 = 0;
+            pre_calc[pre_calc_index] = pc;
+            pre_calc_index += 1;
+            continue;
+          }
+
+          if (y <= 0) {
+            y = 0;
+          }
+          if (x <= 0) {
+            x = 0;
+          }
+
+          int y_low = (int)y;
+          int x_low = (int)x;
+          int y_high;
+          int x_high;
+
+          if (y_low >= height - 1) {
+            y_high = y_low = height - 1;
+            y = (T)y_low;
+          } else {
+            y_high = y_low + 1;
+          }
+
+          if (x_low >= width - 1) {
+            x_high = x_low = width - 1;
+            x = (T)x_low;
+          } else {
+            x_high = x_low + 1;
+          }
+
+          T ly = y - y_low;
+          T lx = x - x_low;
+          T hy = 1. - ly, hx = 1. - lx;
+          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+          // save weights and indices
+          PreCalc<T> pc;
+          pc.pos1 = y_low * width + x_low;
+          pc.pos2 = y_low * width + x_high;
+          pc.pos3 = y_high * width + x_low;
+          pc.pos4 = y_high * width + x_high;
+          pc.w1 = w1;
+          pc.w2 = w2;
+          pc.w3 = w3;
+          pc.w4 = w4;
+          pre_calc[pre_calc_index] = pc;
+
+          pre_calc_index += 1;
+        }
+      }
+    }
+  }
+}
+
+} // namespace detail
+} // namespace ops
+} // namespace vision
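To make the precomputation concrete, here is a rough Python sketch of what gets stored for one sampling point (y, x) that falls inside the feature map; the out-of-bounds early exit and the flat pre_calc vector are omitted, and the function name is invented for illustration:

def bilinear_pre_calc(y, x, height, width):
    # Clamp to the valid range, as the kernel does after its out-of-bounds check.
    y = min(max(y, 0.0), height - 1)
    x = min(max(x, 0.0), width - 1)
    y_low, x_low = int(y), int(x)
    y_high = min(y_low + 1, height - 1)
    x_high = min(x_low + 1, width - 1)

    ly, lx = y - y_low, x - x_low  # fractional offsets inside the pixel cell
    hy, hx = 1.0 - ly, 1.0 - lx
    # The four weights are the areas of the opposite sub-rectangles (see the figure linked above).
    w = (hy * hx, hy * lx, ly * hx, ly * lx)  # w1..w4
    pos = (y_low * width + x_low, y_low * width + x_high,  # pos1..pos4, flattened pixel indices
           y_high * width + x_low, y_high * width + x_high)
    return pos, w

A sampled value is then reconstructed as w1 * v[pos1] + w2 * v[pos2] + w3 * v[pos3] + w4 * v[pos4]. Because pos and w depend only on geometry, they can be reused for every channel, which is why this work is hoisted out of the main kernel loop.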

torchvision/csrc/ops/cpu/roi_align_kernel.cpp

Lines changed: 8 additions & 117 deletions
@@ -1,120 +1,13 @@
 #include <ATen/ATen.h>
 #include <torch/library.h>
 
+#include "./roi_align_common.h"
+
 namespace vision {
 namespace ops {
 
 namespace {
 
-// implementation taken from Caffe2
-template <typename T>
-struct PreCalc {
-  int pos1;
-  int pos2;
-  int pos3;
-  int pos4;
-  T w1;
-  T w2;
-  T w3;
-  T w4;
-};
-
-template <typename T>
-void pre_calc_for_bilinear_interpolate(
-    int height,
-    int width,
-    int pooled_height,
-    int pooled_width,
-    int iy_upper,
-    int ix_upper,
-    T roi_start_h,
-    T roi_start_w,
-    T bin_size_h,
-    T bin_size_w,
-    int roi_bin_grid_h,
-    int roi_bin_grid_w,
-    std::vector<PreCalc<T>>& pre_calc) {
-  int pre_calc_index = 0;
-  for (int ph = 0; ph < pooled_height; ph++) {
-    for (int pw = 0; pw < pooled_width; pw++) {
-      for (int iy = 0; iy < iy_upper; iy++) {
-        const T yy = roi_start_h + ph * bin_size_h +
-            static_cast<T>(iy + .5f) * bin_size_h /
-                static_cast<T>(roi_bin_grid_h); // e.g., 0.5, 1.5
-        for (int ix = 0; ix < ix_upper; ix++) {
-          const T xx = roi_start_w + pw * bin_size_w +
-              static_cast<T>(ix + .5f) * bin_size_w /
-                  static_cast<T>(roi_bin_grid_w);
-
-          T x = xx;
-          T y = yy;
-          // deal with: inverse elements are out of feature map boundary
-          if (y < -1.0 || y > height || x < -1.0 || x > width) {
-            // empty
-            PreCalc<T> pc;
-            pc.pos1 = 0;
-            pc.pos2 = 0;
-            pc.pos3 = 0;
-            pc.pos4 = 0;
-            pc.w1 = 0;
-            pc.w2 = 0;
-            pc.w3 = 0;
-            pc.w4 = 0;
-            pre_calc[pre_calc_index] = pc;
-            pre_calc_index += 1;
-            continue;
-          }
-
-          if (y <= 0) {
-            y = 0;
-          }
-          if (x <= 0) {
-            x = 0;
-          }
-
-          int y_low = (int)y;
-          int x_low = (int)x;
-          int y_high;
-          int x_high;
-
-          if (y_low >= height - 1) {
-            y_high = y_low = height - 1;
-            y = (T)y_low;
-          } else {
-            y_high = y_low + 1;
-          }
-
-          if (x_low >= width - 1) {
-            x_high = x_low = width - 1;
-            x = (T)x_low;
-          } else {
-            x_high = x_low + 1;
-          }
-
-          T ly = y - y_low;
-          T lx = x - x_low;
-          T hy = 1. - ly, hx = 1. - lx;
-          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
-
-          // save weights and indeces
-          PreCalc<T> pc;
-          pc.pos1 = y_low * width + x_low;
-          pc.pos2 = y_low * width + x_high;
-          pc.pos3 = y_high * width + x_low;
-          pc.pos4 = y_high * width + x_high;
-          pc.w1 = w1;
-          pc.w2 = w2;
-          pc.w3 = w3;
-          pc.w4 = w4;
-          pre_calc[pre_calc_index] = pc;
-
-          pre_calc_index += 1;
-        }
-      }
-    }
-  }
-}
-
 template <typename T>
 void roi_align_forward_kernel_impl(
     int n_rois,
@@ -167,17 +60,15 @@ void roi_align_forward_kernel_impl(
     // When the grid is empty, output zeros.
     const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4
 
-    // we want to precalculate indeces and weights shared by all chanels,
-    // this is the key point of optimiation
-    std::vector<PreCalc<T>> pre_calc(
+    // we want to precalculate indices and weights shared by all channels,
+    // this is the key point of optimization
+    std::vector<detail::PreCalc<T>> pre_calc(
         roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height);
-    pre_calc_for_bilinear_interpolate(
+    detail::pre_calc_for_bilinear_interpolate(
        height,
        width,
        pooled_height,
        pooled_width,
-       roi_bin_grid_h,
-       roi_bin_grid_w,
        roi_start_h,
        roi_start_w,
        bin_size_h,
@@ -199,15 +90,15 @@ void roi_align_forward_kernel_impl(
         T output_val = 0.;
         for (int iy = 0; iy < roi_bin_grid_h; iy++) {
           for (int ix = 0; ix < roi_bin_grid_w; ix++) {
-            PreCalc<T> pc = pre_calc[pre_calc_index];
+            detail::PreCalc<T> pc = pre_calc[pre_calc_index];
             output_val += pc.w1 * offset_input[pc.pos1] +
                 pc.w2 * offset_input[pc.pos2] +
                 pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4];
 
             pre_calc_index += 1;
           }
         }
-        output_val /= count;
+        output_val /= count; // Average pooling
 
         output[index] = output_val;
       } // for pw
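The "// Average pooling" comment added above summarizes the reduction: each output bin is the mean of its roi_bin_grid_h x roi_bin_grid_w interpolated samples, with count clamped to at least 1 so that an empty grid yields zeros. A tiny sketch of that step alone, with made-up sample values:

import torch

roi_bin_grid_h, roi_bin_grid_w = 2, 2
samples = torch.tensor([[1.0, 3.0], [2.0, 6.0]])  # interpolated values of one pooled bin
count = max(roi_bin_grid_h * roi_bin_grid_w, 1)   # guard against an empty sampling grid
output_val = samples.sum() / count                # 3.0, the bin's RoIAlign output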
