@@ -170,26 +170,59 @@ def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_
         else:
             self.scales = scales
 
-        self._wh_pairs = []
+        self._wh_pairs = self._generate_wh_pairs(num_outputs)
+
+    def _generate_wh_pairs(self, num_outputs: int, dtype: torch.dtype = torch.float32,
+                           device: torch.device = torch.device("cpu")) -> List[Tensor]:
+        _wh_pairs: List[Tensor] = []
         for k in range(num_outputs):
             # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
             s_k = self.scales[k]
             s_prime_k = math.sqrt(self.scales[k] * self.scales[k + 1])
-            wh_pairs = [(s_k, s_k), (s_prime_k, s_prime_k)]
+            wh_pairs = [[s_k, s_k], [s_prime_k, s_prime_k]]
 
             # Adding 2 pairs for each aspect ratio of the feature map k
             for ar in self.aspect_ratios[k]:
                 sq_ar = math.sqrt(ar)
                 w = self.scales[k] * sq_ar
                 h = self.scales[k] / sq_ar
-                wh_pairs.extend([(w, h), (h, w)])
+                wh_pairs.extend([[w, h], [h, w]])
 
-            self._wh_pairs.append(wh_pairs)
+            _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
+        return _wh_pairs
 
     def num_anchors_per_location(self):
         # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feature map.
         return [2 + 2 * len(r) for r in self.aspect_ratios]
 
+    # Default Boxes calculation based on page 6 of SSD paper
+    def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int],
+                            dtype: torch.dtype = torch.float32) -> Tensor:
+        default_boxes = []
+        for k, f_k in enumerate(grid_sizes):
+            # Now add the default boxes for each width-height pair
+            if self.steps is not None:
+                x_f_k, y_f_k = [img_shape / self.steps[k] for img_shape in image_size]
+            else:
+                y_f_k, x_f_k = f_k
+
+            shifts_x = (torch.arange(0, f_k[1], dtype=dtype) + 0.5) / x_f_k
+            shifts_y = (torch.arange(0, f_k[0], dtype=dtype) + 0.5) / y_f_k
+            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+            shift_x = shift_x.reshape(-1)
+            shift_y = shift_y.reshape(-1)
+
+            shifts = torch.stack((shift_x, shift_y) * len(self._wh_pairs[k]), dim=-1).reshape(-1, 2)
+            # Clipping the default boxes while the boxes are encoded in format (cx, cy, w, h)
+            _wh_pair = self._wh_pairs[k].clamp(min=0, max=1) if self.clip else self._wh_pairs[k]
+            wh_pairs = _wh_pair.repeat((f_k[0] * f_k[1]), 1)
+
+            default_box = torch.cat((shifts, wh_pairs), dim=1)
+
+            default_boxes.append(default_box)
+
+        return torch.cat(default_boxes, dim=0)
+
     def __repr__(self) -> str:
         s = self.__class__.__name__ + '('
         s += 'aspect_ratios={aspect_ratios}'
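The interesting part of this hunk is how `_grid_default_boxes` replaces the per-cell Python loops (removed in the second hunk below) with a single `meshgrid`/`stack`/`reshape`. As a minimal standalone sanity sketch, with toy feature-map size and pair count chosen purely for illustration (not taken from the patch), the interleaving it produces matches the old `j`/`i` loop order:

```python
import torch

# Toy sizes, assumed for illustration only.
f_k = (2, 3)      # feature map: 2 rows, 3 columns
num_pairs = 4     # e.g. 2 default boxes + 2 boxes for one aspect ratio

# Vectorized centers, as in _grid_default_boxes (without the steps branch).
shifts_x = (torch.arange(0, f_k[1], dtype=torch.float32) + 0.5) / f_k[1]
shifts_y = (torch.arange(0, f_k[0], dtype=torch.float32) + 0.5) / f_k[0]
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)

# Repeating the (x, y) columns num_pairs times before stacking emits each
# center num_pairs times in a row: one (cx, cy) per width-height pair.
shifts = torch.stack((shift_x, shift_y) * num_pairs, dim=-1).reshape(-1, 2)

# The centers in the order the removed nested loops produced them.
expected = torch.tensor([[(i + 0.5) / f_k[1], (j + 0.5) / f_k[0]]
                         for j in range(f_k[0])
                         for i in range(f_k[1])
                         for _ in range(num_pairs)])
assert torch.allclose(shifts, expected)
```

Note also that the clip is now applied once, to the cached `_wh_pairs` while the boxes are still in `(cx, cy, w, h)` format, which is what lets `forward()` below reuse one precomputed tensor for every image.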
@@ -203,30 +236,12 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten
         grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
         image_size = image_list.tensors.shape[-2:]
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
-
-        # Default Boxes calculation based on page 6 of SSD paper
-        default_boxes: List[List[float]] = []
-        for k, f_k in enumerate(grid_sizes):
-            # Now add the default boxes for each width-height pair
-            for j in range(f_k[0]):
-                if self.steps is not None:
-                    y_f_k = image_size[1] / self.steps[k]
-                else:
-                    y_f_k = float(f_k[0])
-                cy = (j + 0.5) / y_f_k
-                for i in range(f_k[1]):
-                    if self.steps is not None:
-                        x_f_k = image_size[0] / self.steps[k]
-                    else:
-                        x_f_k = float(f_k[1])
-                    cx = (i + 0.5) / x_f_k
-                    default_boxes.extend([[cx, cy, w, h] for w, h in self._wh_pairs[k]])
+        default_boxes = self._grid_default_boxes(grid_sizes, image_size, dtype=dtype)
+        default_boxes = default_boxes.to(device)
 
         dboxes = []
         for _ in image_list.image_sizes:
-            dboxes_in_image = torch.tensor(default_boxes, dtype=dtype, device=device)
-            if self.clip:
-                dboxes_in_image.clamp_(min=0, max=1)
+            dboxes_in_image = default_boxes
             dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:],
                                          dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1)
             dboxes_in_image[:, 0::2] *= image_size[1]
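The unchanged tail of `forward()` (the context lines above, which the diff cuts off mid-conversion) turns each relative `(cx, cy, w, h)` box into pixel-space corners. Here is a self-contained sketch of that arithmetic, assuming a hypothetical 300x300 image; in the full source the matching y-scaling line follows the x one shown above:

```python
import torch

image_size = (300, 300)  # (height, width) of a hypothetical input image
dboxes_in_image = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # relative (cx, cy, w, h)

# (cx, cy, w, h) -> (x1, y1, x2, y2), still in relative coordinates.
dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:],
                             dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1)

# Scale x coordinates by image width and y coordinates by image height.
dboxes_in_image[:, 0::2] *= image_size[1]
dboxes_in_image[:, 1::2] *= image_size[0]

print(dboxes_in_image)  # tensor([[120.,  90., 180., 210.]])
```

Because `default_boxes` is now computed once and merely aliased per image (`dboxes_in_image = default_boxes`), the patch drops the per-image `torch.tensor(...)` copy and `clamp_` that the old loop repeated for every image in the batch.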