Skip to content

Commit 03bc52c

Browse files
committed
Moving VGG's extra layers to a separate class + adding L2 normalization and scaling.
1 parent bffe4bc commit 03bc52c

File tree

1 file changed

+65
-73
lines changed
  • torchvision/models/detection

1 file changed

+65
-73
lines changed

torchvision/models/detection/ssd.py

Lines changed: 65 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import torch
2+
import torch.nn.functional as F
3+
14
from torch import nn, Tensor
25
from typing import Dict, List, Optional, Tuple
36

@@ -73,17 +76,73 @@ def forward(self, images: List[Tensor],
7376
pass
7477

7578

76-
class MultiFeatureMap(nn.Module):
79+
class SSDFeatureExtractorVGG(nn.Module):
    """Builds the six SSD feature maps on top of a VGG-style backbone.

    The backbone is split at its penultimate max-pooling layer:

    * ``block1`` holds every layer before that max-pool (up to conv4_3).
    * ``block2`` holds the remaining backbone layers except the very last one
      (assumed to be maxpool5, which is replaced), followed by a stride-1
      max-pool, an atrous FC6 convolution and a 1x1 FC7 convolution.
    * ``block3`` .. ``block6`` are the extra convolutional layers
      (conv8_2 .. conv11_2) - page 4, Fig 2 of the SSD paper.

    ``OUT_CHANNELS`` lists the number of channels of each returned feature map.
    """

    OUT_CHANNELS = (512, 1024, 512, 256, 256, 256)

    def __init__(self, backbone: nn.Module):
        """
        Args:
            backbone: an iterable, sliceable container of layers (e.g. the
                ``nn.Sequential`` features of a VGG). Its max-pooling layers
                are modified in place (``ceil_mode`` is set to ``True``).

        Raises:
            ValueError: if the backbone has fewer than two ``nn.MaxPool2d``
                layers, in which case the split point cannot be determined.
        """
        super().__init__()

        maxpools = [(i, layer) for i, layer in enumerate(backbone) if isinstance(layer, nn.MaxPool2d)]
        # Validate before touching the backbone: without two max-pools the slices
        # below would silently degenerate to (almost) the whole backbone.
        if len(maxpools) < 2:
            raise ValueError(
                "backbone must contain at least 2 nn.MaxPool2d layers, found {}".format(len(maxpools))
            )

        # Patch ceil_mode for all maxpool layers of backbone to get the same WxH output sizes as the paper
        for _, layer in maxpools:
            layer.ceil_mode = True
        penultimate_block_pos = maxpools[-2][0]

        # learnable per-channel scale applied after L2 normalization of block1's output
        self.scale_weight = nn.Parameter(torch.ones(self.OUT_CHANNELS[0]) * 20)

        # Multiple Feature maps - page 4, Fig 2 of SSD paper
        self.block1 = nn.Sequential(
            *backbone[:penultimate_block_pos]  # until conv4_3
        )
        self.block2 = nn.Sequential(
            *backbone[penultimate_block_pos:-1],  # until conv5_3, skip maxpool5
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),  # add modified maxpool5
            nn.Conv2d(in_channels=self.OUT_CHANNELS[0],
                      out_channels=1024, kernel_size=3, padding=6, dilation=6),  # FC6 with atrous
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=1024, out_channels=self.OUT_CHANNELS[1], kernel_size=1),  # FC7
            nn.ReLU(inplace=True)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(self.OUT_CHANNELS[1], 256, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, self.OUT_CHANNELS[2], kernel_size=3, padding=1, stride=2),  # conv8_2
            nn.ReLU(inplace=True),
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(self.OUT_CHANNELS[2], 128, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, self.OUT_CHANNELS[3], kernel_size=3, padding=1, stride=2),  # conv9_2
            nn.ReLU(inplace=True),
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(self.OUT_CHANNELS[3], 128, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, self.OUT_CHANNELS[4], kernel_size=3),  # conv10_2
            nn.ReLU(inplace=True),
        )
        self.block6 = nn.Sequential(
            nn.Conv2d(self.OUT_CHANNELS[4], 128, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, self.OUT_CHANNELS[5], kernel_size=3),  # conv11_2
            nn.ReLU(inplace=True),
        )

    def forward(self, x: Tensor) -> List[Tensor]:
        """Return the list of six feature maps, one per block, in block order.

        The first entry is block1's output L2-normalized over the channel
        dimension and multiplied by ``scale_weight``; the un-normalized
        block1 output is what gets fed to the following blocks.
        """
        # L2 normalization + rescaling of 1st block's feature map
        x = self.block1(x)
        rescaled = self.scale_weight.view(1, -1, 1, 1) * F.normalize(x, dim=1)
        output = [rescaled]

        # Calculating Feature maps for the rest blocks
        for block in (self.block2, self.block3, self.block4, self.block5, self.block6):
            x = block(x)
            output.append(x)

        return output
88147

89148

@@ -102,74 +161,7 @@ def _vgg_mfm_backbone(backbone_name, pretrained, trainable_layers=3):
102161
for parameter in b.parameters():
103162
parameter.requires_grad_(False)
104163

105-
# Patch ceil_mode for all maxpool layers of backbone to get the same outputs as Fig. 2 of the SSD paper
106-
for layer in backbone:
107-
if isinstance(layer, nn.MaxPool2d):
108-
layer.ceil_mode = True
109-
110-
# Multiple Feature map definition - page 4, Fig 2 of SSD paper
111-
def build_feature_map_block(layers, out_channels):
112-
block = nn.Sequential(*layers)
113-
block.out_channels = out_channels
114-
return block
115-
116-
penultimate_block_index = stage_indices[-2]
117-
feature_maps = nn.ModuleList([
118-
build_feature_map_block(
119-
backbone[:penultimate_block_index], # until conv4_3
120-
# TODO: add L2 normalization + scaling?
121-
512
122-
),
123-
build_feature_map_block(
124-
(
125-
*backbone[penultimate_block_index:-1], # until conv5_3, skip last maxpool
126-
nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), # add modified maxpool5
127-
nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6), # FC6 with atrous
128-
nn.ReLU(inplace=True),
129-
nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), # FC7
130-
nn.ReLU(inplace=True)
131-
),
132-
1024
133-
),
134-
build_feature_map_block(
135-
(
136-
nn.Conv2d(1024, 256, kernel_size=1),
137-
nn.ReLU(inplace=True),
138-
nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2
139-
nn.ReLU(inplace=True),
140-
),
141-
512,
142-
),
143-
build_feature_map_block(
144-
(
145-
nn.Conv2d(512, 128, kernel_size=1),
146-
nn.ReLU(inplace=True),
147-
nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2
148-
nn.ReLU(inplace=True),
149-
),
150-
256,
151-
),
152-
build_feature_map_block(
153-
(
154-
nn.Conv2d(256, 128, kernel_size=1),
155-
nn.ReLU(inplace=True),
156-
nn.Conv2d(128, 256, kernel_size=3), # conv10_2
157-
nn.ReLU(inplace=True),
158-
),
159-
256,
160-
),
161-
build_feature_map_block(
162-
(
163-
nn.Conv2d(256, 128, kernel_size=1),
164-
nn.ReLU(inplace=True),
165-
nn.Conv2d(128, 256, kernel_size=3), # conv11_2
166-
nn.ReLU(inplace=True),
167-
),
168-
256,
169-
),
170-
])
171-
172-
return MultiFeatureMap(feature_maps)
164+
return SSDVGGFeatureExtractor(backbone)
173165

174166

175167
def ssd_vgg16(pretrained=False, progress=True,

0 commit comments

Comments
 (0)