
Commit e173b8f

Extend SqueezeExcitation to support custom min_value and activation.
1 parent 447a336

2 files changed: 8 additions & 6 deletions

torchvision/models/mobilenetv3.py

Lines changed: 7 additions & 5 deletions
@@ -20,22 +20,24 @@
 
 class SqueezeExcitation(nn.Module):
     # Implemented as described at Figure 4 of the MobileNetV3 paper
-    def __init__(self, input_channels: int, squeeze_factor: int = 4):
+    def __init__(self, input_channels: int, squeeze_factor: int = 4, min_value: Optional[int] = None,
+                 activation_fn: Callable[..., Tensor] = F.hardsigmoid):
         super().__init__()
-        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
+        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8, min_value)
         self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
         self.relu = nn.ReLU(inplace=True)
         self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
+        self.activation_fn = activation_fn
 
-    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
+    def _scale(self, input: Tensor) -> Tensor:
         scale = F.adaptive_avg_pool2d(input, 1)
         scale = self.fc1(scale)
         scale = self.relu(scale)
         scale = self.fc2(scale)
-        return F.hardsigmoid(scale, inplace=inplace)
+        return self.activation_fn(scale)
 
     def forward(self, input: Tensor) -> Tensor:
-        scale = self._scale(input, True)
+        scale = self._scale(input)
         return scale * input
 
 
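For reference, a minimal usage sketch of the extended constructor as it stands after this commit. The channel count, min_value=16, and the torch.sigmoid gate below are illustrative assumptions, not values used by torchvision; the import path reflects where the class lives in this file.

import torch
from torchvision.models.mobilenetv3 import SqueezeExcitation

# Default behaviour: squeeze width derived from input_channels // squeeze_factor,
# rounded by _make_divisible to a multiple of 8, gated with F.hardsigmoid.
se_default = SqueezeExcitation(input_channels=72)

# New arguments from this commit: min_value is forwarded to _make_divisible,
# and activation_fn replaces the previously hard-coded hardsigmoid gate.
se_custom = SqueezeExcitation(input_channels=72, squeeze_factor=4,
                              min_value=16, activation_fn=torch.sigmoid)

x = torch.randn(1, 72, 14, 14)
y = se_custom(x)  # same shape as x; each channel rescaled by its gate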

torchvision/models/quantization/mobilenetv3.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def __init__(self, *args, **kwargs):
         self.skip_mul = nn.quantized.FloatFunctional()
 
     def forward(self, input: Tensor) -> Tensor:
-        return self.skip_mul.mul(self._scale(input, False), input)
+        return self.skip_mul.mul(self._scale(input), input)
 
     def fuse_model(self):
         fuse_modules(self, ['fc1', 'relu'], inplace=True)
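For context, a small standalone check (assuming only torch) of the point behind dropping the inplace flag: with the default activation_fn the gate values are identical whether or not hardsigmoid writes in place, so both the eager and the quantized forward can share the same _scale(input) call.

import torch
import torch.nn.functional as F

scale = torch.randn(1, 16, 1, 1)
new_gate = F.hardsigmoid(scale)                        # new path: self.activation_fn(scale)
old_gate = F.hardsigmoid(scale.clone(), inplace=True)  # old path: F.hardsigmoid(scale, inplace=inplace)
assert torch.equal(new_gate, old_gate)                 # same values; only the in-place write is gone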
