Commit d3f345e

Implement forward methods + temp workarounds to inherit from retina.
1 parent: 34237e4

File tree: 3 files changed (+74 −27 lines)

torchvision/models/detection/retinanet.py

Lines changed: 10 additions & 7 deletions

@@ -454,6 +454,15 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes):
 
         return detections
 
+    def _anchors_per_level(self, features, HWA):
+        # recover level sizes
+        num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
+        HW = 0
+        for v in num_anchors_per_level:
+            HW += v
+        A = HWA // HW
+        return [hw * A for hw in num_anchors_per_level]
+
     def forward(self, images, targets=None):
         # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
         """
@@ -531,13 +540,7 @@ def forward(self, images, targets=None):
             losses = self.compute_loss(targets, head_outputs, anchors)
         else:
             # recover level sizes
-            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
-            HW = 0
-            for v in num_anchors_per_level:
-                HW += v
-            HWA = head_outputs['cls_logits'].size(1)
-            A = HWA // HW
-            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]
+            num_anchors_per_level = self._anchors_per_level(features, head_outputs['cls_logits'].size(1))
 
             # split outputs per level
             split_head_outputs: Dict[str, List[Tensor]] = {}
torchvision/models/detection/ssd.py

Lines changed: 61 additions & 18 deletions

@@ -24,35 +24,68 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes:
         self.regression_head = SSDRegressionHead(in_channels, num_anchors)
 
 
-class SSDClassificationHead(nn.Module):
-    def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
+class SSDScoringHead(nn.Module):
+    def __init__(self, module_list: nn.ModuleList, num_columns: int):
         super().__init__()
-        self.cls_logits = nn.ModuleList()
+        self.module_list = module_list
+        self.num_columns = num_columns
+
+    def get_result_from_module_list(self, x: Tensor, idx: int) -> Tensor:
+        """
+        This is equivalent to self.module_list[idx](x),
+        but torchscript doesn't support this yet
+        """
+        num_blocks = len(self.module_list)
+        if idx < 0:
+            idx += num_blocks
+        i = 0
+        out = x
+        for module in self.module_list:
+            if i == idx:
+                out = module(x)
+            i += 1
+        return out
+
+    def forward(self, x: List[Tensor]) -> Tensor:
+        all_results = []
+
+        for i, features in enumerate(x):
+            results = self.get_result_from_module_list(features, i)
+
+            # Permute output from (N, A * K, H, W) to (N, HWA, K).
+            N, _, H, W = results.shape
+            results = results.view(N, -1, self.num_columns, H, W)
+            results = results.permute(0, 3, 4, 1, 2)
+            results = results.reshape(N, -1, self.num_columns)  # Size=(N, HWA, K)
+
+            all_results.append(results)
+
+        return torch.cat(all_results, dim=1)
+
+
+class SSDClassificationHead(SSDScoringHead):
+    def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int):
+        cls_logits = nn.ModuleList()
         for channels, anchors in zip(in_channels, num_anchors):
-            self.cls_logits.append(nn.Conv2d(channels, num_classes * anchors, kernel_size=3, padding=1))
+            cls_logits.append(nn.Conv2d(channels, num_classes * anchors, kernel_size=3, padding=1))
+        super().__init__(cls_logits, num_classes)
 
     def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor],
                      matched_idxs: List[Tensor]) -> Tensor:
         pass
 
-    def forward(self, x: List[Tensor]) -> Tensor:
-        pass
 
-
-class SSDRegressionHead(nn.Module):
+class SSDRegressionHead(SSDScoringHead):
     def __init__(self, in_channels: List[int], num_anchors: List[int]):
-        super().__init__()
-        self.bbox_reg = nn.ModuleList()
+        bbox_reg = nn.ModuleList()
         for channels, anchors in zip(in_channels, num_anchors):
-            self.bbox_reg.append(nn.Conv2d(channels, 4 * anchors, kernel_size=3, padding=1))
+            bbox_reg.append(nn.Conv2d(channels, 4 * anchors, kernel_size=3, padding=1))
+        super().__init__(bbox_reg, 4)
 
     def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor], anchors: List[Tensor],
                      matched_idxs: List[Tensor]) -> Tensor:
         pass
 
-    def forward(self, x: List[Tensor]) -> Tensor:
-        pass
-
 
 class SSD(RetinaNet):
     def __init__(self, backbone: nn.Module, num_classes: int,
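
The shared SSDScoringHead flattens each level's raw conv output into per-anchor rows before concatenating across levels. A self-contained sketch of the permutation its forward performs on a single feature map, with hypothetical sizes (illustration only):

    import torch

    # Hypothetical sizes: batch 2, 4 anchors, 21 classes, a 38x38 feature map.
    N, A, K, H, W = 2, 4, 21, 38, 38
    results = torch.randn(N, A * K, H, W)      # raw conv output, (N, A*K, H, W)

    results = results.view(N, -1, K, H, W)     # (N, A, K, H, W)
    results = results.permute(0, 3, 4, 1, 2)   # (N, H, W, A, K)
    results = results.reshape(N, -1, K)        # (N, H*W*A, K)
    print(results.shape)                       # torch.Size([2, 5776, 21])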
@@ -80,8 +113,8 @@ def __init__(self, backbone: nn.Module, num_classes: int,
         self.backbone = backbone
 
         # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map.
-        num_anchors = [2 + 2 * len(r) for r in aspect_ratios]
-        self.head = SSDHead(out_channels, num_anchors, num_classes)
+        self.num_anchors = [2 + 2 * len(r) for r in aspect_ratios]
+        self.head = SSDHead(out_channels, self.num_anchors, num_classes)
 
         self.anchor_generator = DBoxGenerator(size, feature_map_sizes, aspect_ratios)
 
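The 2 + 2 * len(r) estimate counts, per location, the two ratio-1 default boxes (at scale s_k and at sqrt(s_k * s_{k+1})) plus a pair (r and 1/r) for each extra aspect ratio. With an SSD300-style configuration (assumed here for illustration, not taken from this commit):

    # SSD300-style aspect ratios per feature map (illustrative values):
    aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]]
    num_anchors = [2 + 2 * len(r) for r in aspect_ratios]
    print(num_anchors)  # [4, 6, 6, 6, 4, 4]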
@@ -97,7 +130,8 @@ def __init__(self, backbone: nn.Module, num_classes: int,
             image_mean = [0.485, 0.456, 0.406]
         if image_std is None:
             image_std = [0.229, 0.224, 0.225]
-        self.transform = GeneralizedRCNNTransform(size, size, image_mean, image_std)
+        self.transform = GeneralizedRCNNTransform(size, size, image_mean, image_std,
+                                                  size_divisible=1)  # TODO: Discuss/refactor this workaround
 
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
@@ -107,6 +141,15 @@ def __init__(self, backbone: nn.Module, num_classes: int,
         # used only on torchscript mode
         self._has_warned = False
 
+    def _anchors_per_level(self, features, HWA):
+        # TODO: Discuss/refactor this workaround
+        num_anchors_per_level = [x.size(2) * x.size(3) * anchors for x, anchors in zip(features, self.num_anchors)]
+        HW = 0
+        for v in num_anchors_per_level:
+            HW += v
+        A = HWA // HW
+        return [hw * A for hw in num_anchors_per_level]
+
     def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor],
                      anchors: List[Tensor]) -> Dict[str, Tensor]:
         pass
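
Unlike the RetinaNet version, the SSD override folds the per-level anchor counts into the location count, so HW already equals HWA, A resolves to 1, and the list is effectively returned unchanged. A sketch with hypothetical shapes (not from this commit):

    import torch

    # Hypothetical SSD feature maps and per-location anchor counts.
    features = [torch.zeros(1, 512, 38, 38), torch.zeros(1, 1024, 19, 19)]
    num_anchors = [4, 6]

    per_level = [f.size(2) * f.size(3) * a for f, a in zip(features, num_anchors)]
    HW = sum(per_level)        # 5776 + 2166 = 7942
    HWA = 7942                 # cls_logits.size(1) would match for these shapes
    A = HWA // HW              # 1
    assert [hw * A for hw in per_level] == per_level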
@@ -203,7 +246,7 @@ def ssd_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int
         pretrained_backbone = False
 
     backbone = _vgg_backbone("vgg16", pretrained_backbone, trainable_layers=trainable_backbone_layers)
-    model = SSD(backbone, num_classes, **kwargs)
+    model = SSD(backbone, num_classes, **kwargs)  # TODO: fix initializations in all new layers
     if pretrained:
         pass  # TODO: load pre-trained COCO weights
     return model

torchvision/models/detection/transform.py

Lines changed: 3 additions & 2 deletions

@@ -66,14 +66,15 @@ class GeneralizedRCNNTransform(nn.Module):
     It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
     """
 
-    def __init__(self, min_size, max_size, image_mean, image_std):
+    def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32):
         super(GeneralizedRCNNTransform, self).__init__()
         if not isinstance(min_size, (list, tuple)):
             min_size = (min_size,)
         self.min_size = min_size
         self.max_size = max_size
         self.image_mean = image_mean
         self.image_std = image_std
+        self.size_divisible = size_divisible
 
     def forward(self,
                 images,  # type: List[Tensor]
@@ -107,7 +108,7 @@ def forward(self,
             targets[i] = target_index
 
         image_sizes = [img.shape[-2:] for img in images]
-        images = self.batch_images(images)
+        images = self.batch_images(images, size_divisible=self.size_divisible)
         image_sizes_list: List[Tuple[int, int]] = []
        for image_size in image_sizes:
            assert len(image_size) == 2
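
For context on the size_divisible=1 workaround above: batch_images pads the batched tensor so its height and width are multiples of size_divisible. With the default of 32, SSD's fixed 300x300 inputs would be padded to 320x320, invalidating the precomputed feature-map sizes. A sketch of the ceil-to-multiple padding (assumed to mirror batch_images' rounding; padded_size is a hypothetical helper, not part of this commit):

    import math

    def padded_size(side: int, size_divisible: int) -> int:
        # round a side length up to the nearest multiple of size_divisible
        return int(math.ceil(side / size_divisible) * size_divisible)

    print(padded_size(300, 32))  # 320 -> breaks SSD's fixed feature-map sizes
    print(padded_size(300, 1))   # 300 -> shapes preserved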
