Skip to content

Commit 48441cc

Browse files
zhiqwang, fmassa, and datumbox
authored
Refactor grid default boxes with torch meshgrid (#3799)
* Refactor grid default boxes with torch.meshgrid
* Fix torch jit tracing
* Only doing the list multiplication once (Co-authored-by: Francisco Massa <fvsmassa@gmail.com>)
* Make _grid_default_boxes private as suggested (Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>)
* Replace list multiplication with torch.repeat
* Move the clipping into _grid_default_boxes to accelerate

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
1 parent 5dd7dfe commit 48441cc

File tree

1 file changed

+40
-25
lines changed

1 file changed

+40
-25
lines changed

torchvision/models/detection/anchor_utils.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -170,26 +170,59 @@ def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_
170170
else:
171171
self.scales = scales
172172

173-
self._wh_pairs = []
173+
self._wh_pairs = self._generate_wh_pairs(num_outputs)
174+
175+
def _generate_wh_pairs(self, num_outputs: int, dtype: torch.dtype = torch.float32,
176+
device: torch.device = torch.device("cpu")) -> List[Tensor]:
177+
_wh_pairs: List[Tensor] = []
174178
for k in range(num_outputs):
175179
# Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
176180
s_k = self.scales[k]
177181
s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
178-
wh_pairs = [(s_k, s_k), (s_prime_k, s_prime_k)]
182+
wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]
179183

180184
# Adding 2 pairs for each aspect ratio of the feature map k
181185
for ar in self.aspect_ratios[k]:
182186
sq_ar = math.sqrt(ar)
183187
w = self.scales[k] * sq_ar
184188
h = self.scales[k] / sq_ar
185-
wh_pairs.extend([(w, h), (h, w)])
189+
wh_pairs.extend([[w, h], [h, w]])
186190

187-
self._wh_pairs.append(wh_pairs)
191+
_wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
192+
return _wh_pairs
188193

189194
def num_anchors_per_location(self):
190195
# Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map.
191196
return [2 + 2 * len(r) for r in self.aspect_ratios]
192197

198+
# Default Boxes calculation based on page 6 of SSD paper
199+
def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int],
200+
dtype: torch.dtype = torch.float32) -> Tensor:
201+
default_boxes = []
202+
for k, f_k in enumerate(grid_sizes):
203+
# Now add the default boxes for each width-height pair
204+
if self.steps is not None:
205+
x_f_k, y_f_k = [img_shape / self.steps[k] for img_shape in image_size]
206+
else:
207+
y_f_k, x_f_k = f_k
208+
209+
shifts_x = (torch.arange(0, f_k[1], dtype=dtype) + 0.5) / x_f_k
210+
shifts_y = (torch.arange(0, f_k[0], dtype=dtype) + 0.5) / y_f_k
211+
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
212+
shift_x = shift_x.reshape(-1)
213+
shift_y = shift_y.reshape(-1)
214+
215+
shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
216+
# Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
217+
_wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
218+
wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)
219+
220+
default_box = torch.cat((shifts, wh_pairs), dim=1)
221+
222+
default_boxes.append(default_box)
223+
224+
return torch.cat(default_boxes, dim=0)
225+
193226
def __repr__(self) -> str:
194227
s = self.__class__.__name__ + '('
195228
s += 'aspect_ratios={aspect_ratios}'
@@ -203,30 +236,12 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten
203236
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
204237
image_size = image_list.tensors.shape[-2:]
205238
dtype, device = feature_maps[0].dtype, feature_maps[0].device
206-
207-
# Default Boxes calculation based on page 6 of SSD paper
208-
default_boxes: List[List[float]] = []
209-
for k, f_k in enumerate(grid_sizes):
210-
# Now add the default boxes for each width-height pair
211-
for j in range(f_k[0]):
212-
if self.steps is not None:
213-
y_f_k = image_size[1] / self.steps[k]
214-
else:
215-
y_f_k = float(f_k[0])
216-
cy = (j + 0.5) / y_f_k
217-
for i in range(f_k[1]):
218-
if self.steps is not None:
219-
x_f_k = image_size[0] / self.steps[k]
220-
else:
221-
x_f_k = float(f_k[1])
222-
cx = (i + 0.5) / x_f_k
223-
default_boxes.extend([[cx, cy, w, h] for w, h in self._wh_pairs[k]])
239+
default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
240+
default_boxes = default_boxes.to(device)
224241

225242
dboxes = []
226243
for _ in image_list.image_sizes:
227-
dboxes_in_image = torch.tensor(default_boxes, dtype=dtype, device=device)
228-
if self.clip:
229-
dboxes_in_image.clamp_(min=0, max=1)
244+
dboxes_in_image = default_boxes
230245
dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:],
231246
dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1)
232247
dboxes_in_image[:, 0::2] *= image_size[1]

0 commit comments

Comments
 (0)