|
20 | 20 |
|
class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation block as described at Figure 4 of the MobileNetV3 paper.

    Globally pools the input, squeezes channels through a bottleneck
    (``fc1`` -> ReLU -> ``fc2``), applies a gating activation, and scales
    the input by the resulting per-channel attention weights.

    Args:
        input_channels: number of channels of the input (and output) tensor.
        squeeze_factor: channel reduction ratio for the bottleneck. Default: 4.
        min_value: optional lower bound forwarded to ``_make_divisible`` when
            rounding the squeezed channel count. Default: ``None``.
        activation_fn: gating activation applied to the excitation branch
            output. Default: ``F.hardsigmoid``.
    """

    def __init__(self, input_channels: int, squeeze_factor: int = 4, min_value: Optional[int] = None,
                 activation_fn: Callable[..., Tensor] = F.hardsigmoid):
        super().__init__()
        # Round the bottleneck width to a hardware-friendly multiple of 8.
        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8, min_value)
        # 1x1 convolutions act as per-channel fully-connected layers.
        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
        self.activation_fn = activation_fn

    def _scale(self, input: Tensor) -> Tensor:
        """Compute the per-channel attention map, shape (N, C, 1, 1)."""
        scale = F.adaptive_avg_pool2d(input, 1)  # global average pool: squeeze step
        scale = self.fc1(scale)
        scale = self.relu(scale)
        scale = self.fc2(scale)
        return self.activation_fn(scale)

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` reweighted channel-wise by the excitation gates."""
        scale = self._scale(input)
        # Broadcasts (N, C, 1, 1) gates over the spatial dimensions.
        return scale * input
|
|
0 commit comments