From 411ce2516280c5ebd7f9377e394e29c62a919fb3 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 19 Aug 2021 11:51:54 +0100 Subject: [PATCH 01/19] Adding code skeleton --- torchvision/models/efficientnet.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 torchvision/models/efficientnet.py diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py new file mode 100644 index 00000000000..d865e3d6bc8 --- /dev/null +++ b/torchvision/models/efficientnet.py @@ -0,0 +1,27 @@ +import torch + +from torch import nn, Tensor +from torch.nn import functional as F +from typing import Any + +from .._internally_replaced_utils import load_state_dict_from_url + +# TODO: refactor this to a common place? +from torchvision.models.mobilenetv2 import ConvBNActivation +from torchvision.models.mobilenetv3 import SqueezeExcitation + + +class MBConvConfig: + pass + + +class MBConv(nn.Module): + pass + + +class EfficientNet(nn.Module): + pass + + +def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + pass From 447a3362fb2eee62134da4352a9b32f16b54d196 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 19 Aug 2021 18:16:20 +0100 Subject: [PATCH 02/19] Adding MBConvConfig. --- torchvision/models/efficientnet.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index d865e3d6bc8..2d5fd219288 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -2,17 +2,38 @@ from torch import nn, Tensor from torch.nn import functional as F -from typing import Any +from typing import Any, Optional from .._internally_replaced_utils import load_state_dict_from_url # TODO: refactor this to a common place? -from torchvision.models.mobilenetv2 import ConvBNActivation +from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible from torchvision.models.mobilenetv3 import SqueezeExcitation +__all__ = [] + + +model_urls = {} + + class MBConvConfig: - pass + # TODO: Add dilation for supporting detection and segmentation pipelines + def __init__(self, + kernel: int, stride: int, + input_channels: int, out_channels: int, expand_ratio: float, se_ratio: float, + skip: bool, width_mult: float): + self.kernel = kernel + self.stride = stride + self.input_channels = self.adjust_channels(input_channels, width_mult) + self.out_channels = self.adjust_channels(out_channels, width_mult) + self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult) + self.se_channels = self.adjust_channels(input_channels, se_ratio * width_mult, 1) + self.skip = skip + + @staticmethod + def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None): + return _make_divisible(channels * width_mult, 8, min_value) class MBConv(nn.Module): From e173b8f9aae4f265f69306a2c3d2495979d29a3b Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 19 Aug 2021 20:03:57 +0100 Subject: [PATCH 03/19] Extend SqueezeExcitation to support custom min_value and activation. --- torchvision/models/mobilenetv3.py | 12 +++++++----- torchvision/models/quantization/mobilenetv3.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index ebe3f510a49..62934cd7255 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -20,22 +20,24 @@ class SqueezeExcitation(nn.Module): # Implemented as described at Figure 4 of the MobileNetV3 paper - def __init__(self, input_channels: int, squeeze_factor: int = 4): + def __init__(self, input_channels: int, squeeze_factor: int = 4, min_value: Optional[int] = None, + activation_fn: Callable[..., Tensor] = F.hardsigmoid): super().__init__() - squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) + squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8, min_value) self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) self.relu = nn.ReLU(inplace=True) self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) + self.activation_fn = activation_fn - def _scale(self, input: Tensor, inplace: bool) -> Tensor: + def _scale(self, input: Tensor) -> Tensor: scale = F.adaptive_avg_pool2d(input, 1) scale = self.fc1(scale) scale = self.relu(scale) scale = self.fc2(scale) - return F.hardsigmoid(scale, inplace=inplace) + return self.activation_fn(scale) def forward(self, input: Tensor) -> Tensor: - scale = self._scale(input, True) + scale = self._scale(input) return scale * input diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 5462af89127..38dfa3893ea 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -22,7 +22,7 @@ def __init__(self, *args, **kwargs): self.skip_mul = nn.quantized.FloatFunctional() def forward(self, input: Tensor) -> Tensor: - return self.skip_mul.mul(self._scale(input, False), input) + return self.skip_mul.mul(self._scale(input), input) def fuse_model(self): fuse_modules(self, ['fc1', 'relu'], inplace=True) From bb1bb17d572c5651da3ce2e42cd6641a19c6aa89 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 09:00:05 +0100 Subject: [PATCH 04/19] Implement MBConv. --- torchvision/models/efficientnet.py | 67 +++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 11 deletions(-) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 2d5fd219288..52219904984 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -2,7 +2,7 @@ from torch import nn, Tensor from torch.nn import functional as F -from typing import Any, Optional +from typing import Any, Callable, List, Optional from .._internally_replaced_utils import load_state_dict_from_url @@ -11,33 +11,78 @@ from torchvision.models.mobilenetv3 import SqueezeExcitation -__all__ = [] +__all__ = ["EfficientNet"] -model_urls = {} +model_urls = { + "efficientnet_b0": "", # TODO: Add weights +} + + +def stochastic_depth(x: Tensor, drop_rate: float) -> Tensor: + survival_rate = 1.0 - drop_rate + keep = torch.rand(size=(x.size(0), ), dtype=x.dtype, device=x.device) > drop_rate + keep = keep[(None, ) * (x.ndim - 1)].T + return x / survival_rate * keep class MBConvConfig: - # TODO: Add dilation for supporting detection and segmentation pipelines def __init__(self, - kernel: int, stride: int, - input_channels: int, out_channels: int, expand_ratio: float, se_ratio: float, - skip: bool, width_mult: float): + kernel: int, stride: int, dilation: int, + input_channels: int, out_channels: int, expand_ratio: float, + width_mult: float) -> None: self.kernel = kernel self.stride = stride + self.dilation = dilation self.input_channels = self.adjust_channels(input_channels, width_mult) self.out_channels = self.adjust_channels(out_channels, width_mult) self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult) - self.se_channels = self.adjust_channels(input_channels, se_ratio * width_mult, 1) - self.skip = skip @staticmethod - def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None): + def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int: return _make_divisible(channels * width_mult, 8, min_value) class MBConv(nn.Module): - pass + def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module], + se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None: + super().__init__() + if not (1 <= cnf.stride <= 2): + raise ValueError('illegal stride value') + + self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels + + layers: List[nn.Module] = [] + activation_layer = nn.SiLU + + # expand + if cnf.expanded_channels != cnf.input_channels: + layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, + norm_layer=norm_layer, activation_layer=activation_layer)) + + # depthwise + stride = 1 if cnf.dilation > 1 else cnf.stride + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, + stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, + norm_layer=norm_layer, activation_layer=activation_layer)) + + # squeeze and excitation + layers.append(se_layer(cnf.expanded_channels, min_value=1, activation_fn=F.sigmoid)) + + # project + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, + activation_layer=nn.Identity)) + + self.block = nn.Sequential(*layers) + self.out_channels = cnf.out_channels + + def forward(self, input: Tensor, drop_rate: Optional[float] = None) -> Tensor: + result = self.block(input) + if self.use_res_connect: + if self.training and drop_rate: + result = stochastic_depth(result, drop_rate) + result += input + return result class EfficientNet(nn.Module): From d15bf7805ffce91faa2f91b37d4bbcfa1c387d37 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 14:49:30 +0100 Subject: [PATCH 05/19] Replace stochastic_depth with operator. --- torchvision/models/efficientnet.py | 13 +++---------- torchvision/ops/stochastic_depth.py | 4 ++-- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 52219904984..73ba5962b76 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -5,6 +5,7 @@ from typing import Any, Callable, List, Optional from .._internally_replaced_utils import load_state_dict_from_url +from torchvision.ops import stochastic_depth # TODO: refactor this to a common place? from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible @@ -19,13 +20,6 @@ } -def stochastic_depth(x: Tensor, drop_rate: float) -> Tensor: - survival_rate = 1.0 - drop_rate - keep = torch.rand(size=(x.size(0), ), dtype=x.dtype, device=x.device) > drop_rate - keep = keep[(None, ) * (x.ndim - 1)].T - return x / survival_rate * keep - - class MBConvConfig: def __init__(self, kernel: int, stride: int, dilation: int, @@ -76,11 +70,10 @@ def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module], self.block = nn.Sequential(*layers) self.out_channels = cnf.out_channels - def forward(self, input: Tensor, drop_rate: Optional[float] = None) -> Tensor: + def forward(self, input: Tensor, drop_rate: float = 0.0) -> Tensor: result = self.block(input) if self.use_res_connect: - if self.training and drop_rate: - result = stochastic_depth(result, drop_rate) + result = stochastic_depth(result, drop_rate, "row", training=self.training) result += input return result diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py index f3338242a76..0b95e7cca67 100644 --- a/torchvision/ops/stochastic_depth.py +++ b/torchvision/ops/stochastic_depth.py @@ -22,12 +22,12 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) """ if p < 0.0 or p > 1.0: raise ValueError("drop probability has to be between 0 and 1, but got {}".format(p)) + if mode not in ["batch", "row"]: + raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode)) if not training or p == 0.0: return input survival_rate = 1.0 - p - if mode not in ["batch", "row"]: - raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode)) size = [1] * input.ndim if mode == "row": size[0] = input.shape[0] From b78399b028925a80d084a108767d073b6d980f76 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 20 Aug 2021 19:07:59 +0100 Subject: [PATCH 06/19] Adding the rest of the EfficientNet implementation --- ...odelTester.test_efficientnet_b0_expect.pkl | Bin 0 -> 939 bytes torchvision/models/__init__.py | 1 + torchvision/models/efficientnet.py | 180 ++++++++++++++++-- 3 files changed, 166 insertions(+), 15 deletions(-) create mode 100644 test/expect/ModelTester.test_efficientnet_b0_expect.pkl diff --git a/test/expect/ModelTester.test_efficientnet_b0_expect.pkl b/test/expect/ModelTester.test_efficientnet_b0_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..8d56c45f46e79f29ebbb9d71ace2418eb6d1d12b GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK66e& zzBa*CZkstZcOCh&;`pTFE9(w7ue=q$cjf2i_*Gk!A8J~$v1@J!dbQ#XhoqL@@&?WS zrxs{dt@)rC)b@L2@{g`npRT%THq21e*z%`MW6j?$nmI4)S8`=quKZNysCB|OKr>ma zP^)rd{;F>~zOPK&thOpyH+WTJ$5PEB>K9iqTs*ZbrN2otf7i~HENyI>>o+#7`rhoM zb^Za5)`ZEk)Ip&YSv&jO31CQpFz(ReXRwBcR#|FMF)$X~oXm*~E~JoyFparDHeZ~V z9?Ar?6@&x489@|0O(Msk07wD_pr=rD-N=68L(%yP$V1kxZ-A~B*;V`~dL@7^gz1Hb zL4Y?Kn+{Zw9J4N5IVdrM0F2%a;WA7DdlKYbHc;MR@PsM=Wr6^2RyL3rGZ2E*L(~EQ De!%?h literal 0 HcmV?d00001 diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 283e544e98e..3c1519c1b42 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -8,6 +8,7 @@ from .mobilenet import * from .mnasnet import * from .shufflenetv2 import * +from .efficientnet import * from . import segmentation from . import detection from . import video diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 73ba5962b76..cca46661cfd 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -1,46 +1,55 @@ +import copy +import math import torch +from functools import partial from torch import nn, Tensor from torch.nn import functional as F -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Sequence from .._internally_replaced_utils import load_state_dict_from_url -from torchvision.ops import stochastic_depth +from torchvision.ops import StochasticDepth # TODO: refactor this to a common place? from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible from torchvision.models.mobilenetv3 import SqueezeExcitation -__all__ = ["EfficientNet"] +__all__ = ["EfficientNet", "efficientnet_b0"] model_urls = { - "efficientnet_b0": "", # TODO: Add weights + "efficientnet_b0": None, # TODO: Add weights } class MBConvConfig: + # Stores information listed at Tables 1 of the EfficientNet paper def __init__(self, - kernel: int, stride: int, dilation: int, - input_channels: int, out_channels: int, expand_ratio: float, - width_mult: float) -> None: + expand_ratio: float, kernel: int, stride: int, + input_channels: int, out_channels: int, num_layers: int, + width_mult: float, depth_mult: float) -> None: + self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult) self.kernel = kernel self.stride = stride - self.dilation = dilation self.input_channels = self.adjust_channels(input_channels, width_mult) self.out_channels = self.adjust_channels(out_channels, width_mult) - self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult) + self.num_layers = self.adjust_depth(num_layers, depth_mult) @staticmethod def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int: return _make_divisible(channels * width_mult, 8, min_value) + @staticmethod + def adjust_depth(num_layers: int, depth_mult: float): + return int(math.ceil(num_layers * depth_mult)) + class MBConv(nn.Module): - def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module], + def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module], se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None: super().__init__() + if not (1 <= cnf.stride <= 2): raise ValueError('illegal stride value') @@ -55,9 +64,8 @@ def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module], norm_layer=norm_layer, activation_layer=activation_layer)) # depthwise - stride = 1 if cnf.dilation > 1 else cnf.stride layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, - stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, + stride=cnf.stride, groups=cnf.expanded_channels, norm_layer=norm_layer, activation_layer=activation_layer)) # squeeze and excitation @@ -68,19 +76,161 @@ def __init__(self, cnf: MBConvConfig, norm_layer: Callable[..., nn.Module], activation_layer=nn.Identity)) self.block = nn.Sequential(*layers) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.out_channels = cnf.out_channels def forward(self, input: Tensor, drop_rate: float = 0.0) -> Tensor: result = self.block(input) if self.use_res_connect: - result = stochastic_depth(result, drop_rate, "row", training=self.training) + result = self.stochastic_depth(result) result += input return result class EfficientNet(nn.Module): - pass + def __init__( + self, + inverted_residual_setting: List[MBConvConfig], + dropout: float, + stochastic_depth_prob: float = 0.2, + num_classes: int = 1000, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any + ) -> None: + """ + EfficientNet main class + + Args: + inverted_residual_setting (List[MBConvConfig]): Network structure + dropout (float): The droupout probability + stochastic_depth_prob (float): The stochastic depth probability + num_classes (int): Number of classes + block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet + norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use + """ + super().__init__() + + if not inverted_residual_setting: + raise ValueError("The inverted_residual_setting should not be empty") + elif not (isinstance(inverted_residual_setting, Sequence) and + all([isinstance(s, MBConvConfig) for s in inverted_residual_setting])): + raise TypeError("The inverted_residual_setting should be List[MBConvConfig]") + + if block is None: + block = MBConv + + if norm_layer is None: + norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) + + layers: List[nn.Module] = [] + + # building first layer + firstconv_output_channels = inverted_residual_setting[0].input_channels + layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, + activation_layer=nn.SiLU)) + + # building inverted residual blocks + total_stage_blocks = sum([cnf.num_layers for cnf in inverted_residual_setting]) + stage_block_id = 0 + for cnf in inverted_residual_setting: + stage: List[nn.Module] = [] + for _ in range(cnf.num_layers): + # copy to avoid modifications. shallow copy is enough + block_cnf = copy.copy(cnf) + + # overwrite info if not the first conv in the stage + if stage: + block_cnf.input_channels = block_cnf.out_channels + block_cnf.stride = 1 + + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks + + stage.append(block(block_cnf, sd_prob, norm_layer)) + stage_block_id += 1 + + layers.append(nn.Sequential(*stage)) + + # building last several layers + lastconv_input_channels = inverted_residual_setting[-1].out_channels + lastconv_output_channels = 4 * lastconv_input_channels + layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, + norm_layer=norm_layer, activation_layer=nn.SiLU)) + + self.features = nn.Sequential(*layers) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Sequential( + nn.Dropout(p=dropout, inplace=True), + nn.Linear(lastconv_output_channels, num_classes), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + init_range = 1.0 / math.sqrt(m.out_features) + nn.init.uniform_(m.weight, -init_range, init_range) + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor) -> Tensor: + x = self.features(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + x = self.classifier(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _efficientnet_conf(width_mult: float, depth_mult: float, **kwargs: Any) -> List[MBConvConfig]: + bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult) + inverted_residual_setting = [ + bneck_conf(1, 3, 1, 32, 16, 1), + bneck_conf(6, 3, 2, 16, 24, 2), + bneck_conf(6, 5, 2, 24, 40, 2), + bneck_conf(6, 3, 2, 40, 80, 3), + bneck_conf(6, 5, 1, 80, 112, 3), + bneck_conf(6, 5, 2, 112, 192, 4), + bneck_conf(6, 3, 1, 192, 320, 1), + ] + return inverted_residual_setting + + +def _efficientnet_model( + arch: str, + inverted_residual_setting: List[MBConvConfig], + dropout: float, + pretrained: bool, + progress: bool, + **kwargs: Any +) -> EfficientNet: + model = EfficientNet(inverted_residual_setting, dropout, **kwargs) + if pretrained: + if model_urls.get(arch, None) is None: + raise ValueError("No checkpoint is available for model type {}".format(arch)) + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: - pass + """ + Constructs a EfficientNet B0 architecture from + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.0, **kwargs) + return _efficientnet_model("efficientnet_b0", inverted_residual_setting, 0.2, pretrained, progress, **kwargs) From 990826bb001cc9ff487dfe8150599320a114a01d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 21 Aug 2021 00:14:51 +0100 Subject: [PATCH 07/19] Update torchvision/models/efficientnet.py --- torchvision/models/efficientnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index cca46661cfd..d27006c79b9 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -79,7 +79,7 @@ def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") self.out_channels = cnf.out_channels - def forward(self, input: Tensor, drop_rate: float = 0.0) -> Tensor: + def forward(self, input: Tensor) -> Tensor: result = self.block(input) if self.use_res_connect: result = self.stochastic_depth(result) From 697eee9249f3fa3ddbf779143208480fa376ac67 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 23 Aug 2021 14:01:21 +0100 Subject: [PATCH 08/19] Replacing 1st activation of SE with SiLU. --- torchvision/models/efficientnet.py | 3 ++- torchvision/models/mobilenetv3.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index d27006c79b9..8601fc33f06 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -69,7 +69,8 @@ def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: norm_layer=norm_layer, activation_layer=activation_layer)) # squeeze and excitation - layers.append(se_layer(cnf.expanded_channels, min_value=1, activation_fn=F.sigmoid)) + layers.append(se_layer(cnf.expanded_channels, min_value=1, + activation_layer=activation_layer, activation_fn=F.sigmoid)) # project layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 62934cd7255..ee6c4104859 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -21,11 +21,12 @@ class SqueezeExcitation(nn.Module): # Implemented as described at Figure 4 of the MobileNetV3 paper def __init__(self, input_channels: int, squeeze_factor: int = 4, min_value: Optional[int] = None, + activation_layer: Callable[..., nn.Module] = nn.ReLU, activation_fn: Callable[..., Tensor] = F.hardsigmoid): super().__init__() squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8, min_value) self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) - self.relu = nn.ReLU(inplace=True) + self.relu = activation_layer() self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) self.activation_fn = activation_fn From 8ff76044e9a5b99a71e4ef892f2fd0579322d6e9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 23 Aug 2021 16:03:14 +0100 Subject: [PATCH 09/19] Adding efficientnet_b3. --- ...odelTester.test_efficientnet_b3_expect.pkl | Bin 0 -> 939 bytes torchvision/models/efficientnet.py | 20 +++++++++++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) create mode 100644 test/expect/ModelTester.test_efficientnet_b3_expect.pkl diff --git a/test/expect/ModelTester.test_efficientnet_b3_expect.pkl b/test/expect/ModelTester.test_efficientnet_b3_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..f46d16c21b218d0ad0364b1f8d81fe4721e0fcad GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5{Jo6U?mURgB6$F#A<{sS-#?$m)weF2C|y=e}1e`y|qen#tZ|^&M6O8)|t&( z`5}T`Gh2Mi$}LlvRxFv!sM%k4OXG{}M2#b|jT+PLZqO*X_j|>oC9_u?+O4Cx=ID%- z8*VON=51-N9&>Qn$^zjhE1dS5Ybtz9QGYPYP2I1ARdaew-inr|IhxCUX0PP2TB7cf z@nYq<`%EhpX7AQ;Ry?(0!~dn43(R|0Sk0?<<^x^84Y@uBE^1>_;?);B=ci|i_X6ulBa7sB*H z!yv$$jZFutNRC+-t{jw@K>$W?hj1AtfjtTGE*mIsFnB^0fHFaVH!B-Rj2Q?)>LF?Y D{CM|2 literal 0 HcmV?d00001 diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 8601fc33f06..683bdd250b4 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -15,11 +15,12 @@ from torchvision.models.mobilenetv3 import SqueezeExcitation -__all__ = ["EfficientNet", "efficientnet_b0"] +__all__ = ["EfficientNet", "efficientnet_b0", "efficientnet_b3"] -model_urls = { - "efficientnet_b0": None, # TODO: Add weights +model_urls = { # TODO: Add weights + "efficientnet_b0": None, + "efficientnet_b3": None, } @@ -235,3 +236,16 @@ def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: A """ inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.0, **kwargs) return _efficientnet_model("efficientnet_b0", inverted_residual_setting, 0.2, pretrained, progress, **kwargs) + + +def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + """ + Constructs a EfficientNet B3 architecture from + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + inverted_residual_setting = _efficientnet_conf(width_mult=1.2, depth_mult=1.4, **kwargs) + return _efficientnet_model("efficientnet_b3", inverted_residual_setting, 0.3, pretrained, progress, **kwargs) From ca9e619907b6572f74ddfd1bae5774cf53851222 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 23 Aug 2021 19:26:06 +0100 Subject: [PATCH 10/19] Replace mobilenetv3 assets with custom. --- ...odelTester.test_efficientnet_b0_expect.pkl | Bin 939 -> 939 bytes ...odelTester.test_efficientnet_b3_expect.pkl | Bin 939 -> 939 bytes torchvision/models/efficientnet.py | 38 +++++++++++++----- torchvision/models/mobilenetv3.py | 15 +++---- .../models/quantization/mobilenetv3.py | 2 +- 5 files changed, 35 insertions(+), 20 deletions(-) diff --git a/test/expect/ModelTester.test_efficientnet_b0_expect.pkl b/test/expect/ModelTester.test_efficientnet_b0_expect.pkl index 8d56c45f46e79f29ebbb9d71ace2418eb6d1d12b..334611a2e8a9d2bbba10a095e532477cf23acc71 100644 GIT binary patch delta 230 zcmV_)2$AWolCjbBd delta 230 zcmV#QdnolE9X)kZ0Rm?0>%{Dvs3{q!kp>3>nF3TZj1@^(Qg$W2fw zV-|5Ld9iP*_PF<{VY4QxV=YsvfrzCk#3#|H0MW{(WRHR=Z@Rgu1cnAFud#xv_k%(z g&)^IykdvAxP)i30S$&(%lMn*X1X+EX&XeQ<#}3ADlmGw# diff --git a/test/expect/ModelTester.test_efficientnet_b3_expect.pkl b/test/expect/ModelTester.test_efficientnet_b3_expect.pkl index f46d16c21b218d0ad0364b1f8d81fe4721e0fcad..3d952a542908742882f26acbdf50b33114b2c523 100644 GIT binary patch delta 230 zcmVlJ}*-!29WV7xO>$p8zzRS!0@4_ gv;ns$<}sqFP)i30*ai?olMn*X1lR@;LzCnJ$D)92od5s; delta 230 zcmVdD>|H2Uq^GFXN*<`DFdZpB{P?IP+Nvp-m@p}c zl;Ej;HkzsMSO+O<7qqFglme)vlL0A@e%dJXJ&`EH9f2s8-LNQh-utNHq?@S1yDTZJ z#h9tE*{7yUIX5RYkgpf`=EI)1JwvVQ?6e3g$Vayzpqh69%=($#UP*eg7! g@NE((0+RKqP)i30Mhv)~lMn*X1V#+Fo|EJP$L8v4 Tensor: + scale = F.adaptive_avg_pool2d(input, 1) + scale = self.fc1(scale) + scale = F.silu(scale, inplace=True) + scale = self.fc2(scale) + return F.hardsigmoid(scale, inplace=True) + + def forward(self, input: Tensor) -> Tensor: + scale = self._scale(input) + return scale * input + + class MBConvConfig: # Stores information listed at Tables 1 of the EfficientNet paper def __init__(self, expand_ratio: float, kernel: int, stride: int, input_channels: int, out_channels: int, num_layers: int, width_mult: float, depth_mult: float) -> None: - self.expanded_channels = self.adjust_channels(input_channels, expand_ratio * width_mult) + self.expand_ratio = expand_ratio self.kernel = kernel self.stride = stride self.input_channels = self.adjust_channels(input_channels, width_mult) self.out_channels = self.adjust_channels(out_channels, width_mult) - self.num_layers = self.adjust_depth(num_layers, depth_mult) + self.num_layers = self.adjust_depth(num_layers, depth_mult) # TODO: add __repr__ @staticmethod def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int: @@ -60,21 +77,22 @@ def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: activation_layer = nn.SiLU # expand - if cnf.expanded_channels != cnf.input_channels: - layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, + expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) + if expanded_channels != cnf.input_channels: + layers.append(ConvBNActivation(cnf.input_channels, expanded_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation_layer)) # depthwise - layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, - stride=cnf.stride, groups=cnf.expanded_channels, + layers.append(ConvBNActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel, + stride=cnf.stride, groups=expanded_channels, norm_layer=norm_layer, activation_layer=activation_layer)) # squeeze and excitation - layers.append(se_layer(cnf.expanded_channels, min_value=1, - activation_layer=activation_layer, activation_fn=F.sigmoid)) + squeeze_channels = max(1, cnf.input_channels // 4) + layers.append(se_layer(expanded_channels, squeeze_channels)) # project - layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, + layers.append(ConvBNActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity)) self.block = nn.Sequential(*layers) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index ee6c4104859..ebe3f510a49 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -20,25 +20,22 @@ class SqueezeExcitation(nn.Module): # Implemented as described at Figure 4 of the MobileNetV3 paper - def __init__(self, input_channels: int, squeeze_factor: int = 4, min_value: Optional[int] = None, - activation_layer: Callable[..., nn.Module] = nn.ReLU, - activation_fn: Callable[..., Tensor] = F.hardsigmoid): + def __init__(self, input_channels: int, squeeze_factor: int = 4): super().__init__() - squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8, min_value) + squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) - self.relu = activation_layer() + self.relu = nn.ReLU(inplace=True) self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) - self.activation_fn = activation_fn - def _scale(self, input: Tensor) -> Tensor: + def _scale(self, input: Tensor, inplace: bool) -> Tensor: scale = F.adaptive_avg_pool2d(input, 1) scale = self.fc1(scale) scale = self.relu(scale) scale = self.fc2(scale) - return self.activation_fn(scale) + return F.hardsigmoid(scale, inplace=inplace) def forward(self, input: Tensor) -> Tensor: - scale = self._scale(input) + scale = self._scale(input, True) return scale * input diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 38dfa3893ea..5462af89127 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -22,7 +22,7 @@ def __init__(self, *args, **kwargs): self.skip_mul = nn.quantized.FloatFunctional() def forward(self, input: Tensor) -> Tensor: - return self.skip_mul.mul(self._scale(input), input) + return self.skip_mul.mul(self._scale(input, False), input) def fuse_model(self): fuse_modules(self, ['fc1', 'relu'], inplace=True) From 627dbe58cc4d198a442796f42b7cb682ad4e6a66 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 Aug 2021 13:42:24 +0100 Subject: [PATCH 11/19] Switch to standard sigmoid and reconfiguring BN. --- .../ModelTester.test_efficientnet_b0_expect.pkl | Bin 939 -> 939 bytes .../ModelTester.test_efficientnet_b3_expect.pkl | Bin 939 -> 939 bytes torchvision/models/efficientnet.py | 4 ++-- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/expect/ModelTester.test_efficientnet_b0_expect.pkl b/test/expect/ModelTester.test_efficientnet_b0_expect.pkl index 334611a2e8a9d2bbba10a095e532477cf23acc71..1de871ce0fbea9ddbab7e315b05f864bc5f6fa53 100644 GIT binary patch delta 230 zcmV`f{_4yvc`Y`Q4`Qd6q3>)c!!QD*YE%;hw;{_ExaeH*ySmzbUrMqzRPv0+ff0j gL>IECEbHv3P)i30F~yPIlMn*X1Tn>t-jn14#}|)o3jhEB delta 230 zcmV_)2$AWolCjbBd diff --git a/test/expect/ModelTester.test_efficientnet_b3_expect.pkl b/test/expect/ModelTester.test_efficientnet_b3_expect.pkl index 3d952a542908742882f26acbdf50b33114b2c523..989d6782fe799c4833239c51f08e8375a6592179 100644 GIT binary patch delta 230 zcmV{oLTLmfP%U-DVNW`fu<1;Dob)>0y1{|siXeX+G z@1Lj@43#PPN0Og9&1$S`B5_J!^#cx@V~ud~3ZswD*}(jvI1Fhjqn^CVU%;Q#X}N5j`C+k}Uyz)Y#8 ga`w0=miVKoP)i306e?e_lMn*X1QaS?v6JKi$F9_DE&u=k delta 230 zcmVlJ}*-!29WV7xO>$p8zzRS!0@4_ gv;ns$<}sqFP)i30*ai?olMn*X1lR@;LzCnJ$D)92od5s; diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 10420b6ac42..0c547c34c47 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -34,7 +34,7 @@ def _scale(self, input: Tensor) -> Tensor: scale = self.fc1(scale) scale = F.silu(scale, inplace=True) scale = self.fc2(scale) - return F.hardsigmoid(scale, inplace=True) + return scale.sigmoid() def forward(self, input: Tensor) -> Tensor: scale = self._scale(input) @@ -141,7 +141,7 @@ def __init__( block = MBConv if norm_layer is None: - norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) + norm_layer = partial(nn.BatchNorm2d, momentum=0.003) layers: List[nn.Module] = [] From 4fc26bc5815ee28fb863b9cac763187e640615fe Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 Aug 2021 17:05:08 +0100 Subject: [PATCH 12/19] Reconfiguration of efficientnet. --- references/classification/presets.py | 6 ++++-- references/classification/train.py | 17 +++++++++++++++-- ...ModelTester.test_efficientnet_b0_expect.pkl | Bin 939 -> 939 bytes ...ModelTester.test_efficientnet_b3_expect.pkl | Bin 939 -> 939 bytes torchvision/models/efficientnet.py | 2 +- 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/references/classification/presets.py b/references/classification/presets.py index 6bb389ba8db..ce5a6fe414f 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -1,4 +1,5 @@ from torchvision.transforms import autoaugment, transforms +from torchvision.transforms.functional import InterpolationMode class ClassificationPresetTrain: @@ -24,10 +25,11 @@ def __call__(self, img): class ClassificationPresetEval: - def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): + def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), + interpolation=InterpolationMode.BILINEAR): self.transforms = transforms.Compose([ - transforms.Resize(resize_size), + transforms.Resize(resize_size, interpolation=interpolation), transforms.CenterCrop(crop_size), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), diff --git a/references/classification/train.py b/references/classification/train.py index b4e9d274662..fcd4767fcd2 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -6,6 +6,7 @@ import torch.utils.data from torch import nn import torchvision +from torchvision.transforms.functional import InterpolationMode import presets import utils @@ -82,7 +83,18 @@ def _get_cache_path(filepath): def load_data(traindir, valdir, args): # Data loading code print("Loading data") - resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224) + resize_size, crop_size = 256, 224 + interpolation = InterpolationMode.BILINEAR + if args.model == 'inception_v3': + resize_size, crop_size = 342, 299 + elif args.model.startswith('efficientnet_'): + sizes = { + 'B0': 224, 'B1': 240, 'B2': 260, 'B3': 300, + 'B4': 380, 'B5': 456, 'B6': 528, 'B7': 600, + } + e_type = args.model.replace('efficientnet_', '').upper() + resize_size = crop_size = sizes[e_type] + interpolation = InterpolationMode.BICUBIC print("Loading training data") st = time.time() @@ -113,7 +125,8 @@ def load_data(traindir, valdir, args): else: dataset_test = torchvision.datasets.ImageFolder( valdir, - presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size)) + presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, + interpolation=interpolation)) if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) diff --git a/test/expect/ModelTester.test_efficientnet_b0_expect.pkl b/test/expect/ModelTester.test_efficientnet_b0_expect.pkl index 1de871ce0fbea9ddbab7e315b05f864bc5f6fa53..c299eac15af0db2a6f55347c57e82fd168bc47e9 100644 GIT binary patch delta 230 zcmVlF`cSG0l-?*BR1zwNZ{{h?n-r-40Y@rQ ztY@lIyLPHD=bI>)@kT0&9SSOdEmEezS&k`kwu31+yyK}|^(3j#&=IIC$o(mt`_ZV-m{F+p=q;*#VJ%TAp!~(B@fg~vqt|t+ zhXBT@9`3U#@oe=d)5)qSd|FYeE&q%ugcJWM#Ff;kfb1ozjRh#G##<_?4jXf-`f{_4yvc`Y`Q4`Qd6q3>)c!!QD*YE%;hw;{_ExaeH*ySmzbUrMqzRPv0+ff0j gL>IECEbHv3P)i30F~yPIlMn*X1Tn>t-jn14#}|)o3jhEB diff --git a/test/expect/ModelTester.test_efficientnet_b3_expect.pkl b/test/expect/ModelTester.test_efficientnet_b3_expect.pkl index 989d6782fe799c4833239c51f08e8375a6592179..a93d3a5612729867bacdc1b65a183f6ee8dcb654 100644 GIT binary patch delta 230 zcmVH{e!ZCt2JF~X^f?=dMetD>osRT`=dmnEtK zDxRorZj&iuv5=|!zEmkvC{L)=35uz}jNm8Iuf(aK0;(lG|Bop!Y8WZ{bsnj8uuG}V zU|6YO&kHA{_n4{qi&H7EhYqRS|468;9Hgn(;!3EVwTLK~bHq`n7N*ImlDc>)z(ND2 zmTZNojHO_xtn%wA{RP)i30&ZDhHlMn*X1kR(aMw8?M$KkebWdHyG delta 230 zcmV{oLTLmfP%U-DVNW`fu<1;Dob)>0y1{|siXeX+G z@1Lj@43#PPN0Og9&1$S`B5_J!^#cx@V~ud~3ZswD*}(jvI1Fhjqn^CVU%;Q#X}N5j`C+k}Uyz)Y#8 ga`w0=miVKoP)i306e?e_lMn*X1QaS?v6JKi$F9_DE&u=k diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 0c547c34c47..400c8a3855b 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -141,7 +141,7 @@ def __init__( block = MBConv if norm_layer is None: - norm_layer = partial(nn.BatchNorm2d, momentum=0.003) + norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) layers: List[nn.Module] = [] From 14ce91f66feaa5d6f54ecc579932288f8bd2db91 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 Aug 2021 18:49:27 +0100 Subject: [PATCH 13/19] Add repr --- references/classification/train.py | 6 +++--- torchvision/models/efficientnet.py | 13 ++++++++++++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index fcd4767fcd2..a7dcce21411 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -89,10 +89,10 @@ def load_data(traindir, valdir, args): resize_size, crop_size = 342, 299 elif args.model.startswith('efficientnet_'): sizes = { - 'B0': 224, 'B1': 240, 'B2': 260, 'B3': 300, - 'B4': 380, 'B5': 456, 'B6': 528, 'B7': 600, + 'b0': 224, 'b1': 240, 'b2': 260, 'b3': 300, + 'b4': 380, 'b5': 456, 'b6': 528, 'b7': 600, } - e_type = args.model.replace('efficientnet_', '').upper() + e_type = args.model.replace('efficientnet_', '') resize_size = crop_size = sizes[e_type] interpolation = InterpolationMode.BICUBIC diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 400c8a3855b..e9293b6f7e9 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -52,7 +52,18 @@ def __init__(self, self.stride = stride self.input_channels = self.adjust_channels(input_channels, width_mult) self.out_channels = self.adjust_channels(out_channels, width_mult) - self.num_layers = self.adjust_depth(num_layers, depth_mult) # TODO: add __repr__ + self.num_layers = self.adjust_depth(num_layers, depth_mult) + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'expand_ratio={expand_ratio}' + s += ', kernel={kernel}' + s += ', stride={stride}' + s += ', input_channels={input_channels}' + s += ', out_channels={out_channels}' + s += ', num_layers={num_layers}' + s += ')' + return s.format(**self.__dict__) @staticmethod def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int: From 0dca77d9a65375488ab98811712a069813dd0433 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 Aug 2021 18:51:30 +0100 Subject: [PATCH 14/19] Add weights. --- ...odelTester.test_efficientnet_b1_expect.pkl | Bin 0 -> 939 bytes ...odelTester.test_efficientnet_b2_expect.pkl | Bin 0 -> 939 bytes ...odelTester.test_efficientnet_b4_expect.pkl | Bin 0 -> 939 bytes ...odelTester.test_efficientnet_b5_expect.pkl | Bin 0 -> 939 bytes ...odelTester.test_efficientnet_b6_expect.pkl | Bin 0 -> 939 bytes ...odelTester.test_efficientnet_b7_expect.pkl | Bin 0 -> 939 bytes torchvision/models/efficientnet.py | 98 ++++++++++++++++-- 7 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 test/expect/ModelTester.test_efficientnet_b1_expect.pkl create mode 100644 test/expect/ModelTester.test_efficientnet_b2_expect.pkl create mode 100644 test/expect/ModelTester.test_efficientnet_b4_expect.pkl create mode 100644 test/expect/ModelTester.test_efficientnet_b5_expect.pkl create mode 100644 test/expect/ModelTester.test_efficientnet_b6_expect.pkl create mode 100644 test/expect/ModelTester.test_efficientnet_b7_expect.pkl diff --git a/test/expect/ModelTester.test_efficientnet_b1_expect.pkl b/test/expect/ModelTester.test_efficientnet_b1_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7a3731002e383cc020a76ac8782dfea67b0dce9f GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5@(uU`^uw>E@;L?Z(Y%0!>74=wce^ZIeMBMQPWp#^J-bu_tREOo}FoBmA2R_ zCGIJjR!l!M_#)?MGHA&!ojS=_YsZzC6=GT6GR;Lm&AeHkHSJ%W*PQ>& zODjIyX_Zk*m1e}n#VdBbyR9K{GI^DKzq}ShY>(DT&V?%{ob%9(T{mHsRa>&wtVd~D zl5SdBr~Xy00) zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5-0CNzNX}o3{4l#oti>ZPOMzaAE;U6SGwwg-vag3mW7(}(z%-RbLE7$jv5yJUsmpVEUHx?(x>6@ooiK`Pv=UzFKk-1Z(21ko%*w);DVIaa!W?7 zN6wuqcG-z*fkNx_u`Sn507D9dafcQ^gEc&~%2JDpfwAD`WKLvoA%z@-Y0L$(`Qp6v zP$r zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK636g`=!%c0?=PRE6tm*isa%Z}mFJgFnDJs6zk9fbxBGGRt6h~V)^YoUU8u2+46+bHEJK$EYV;&xJkpVDp+IM-w=(u(1jWNbuob`Lr=~IseVr8KDV0eaCx9UZ!ni|=pTQa)T4kw4#lTo_b229~xR62)!ZhXr*?e(c zdMFdnRuB&GW&~02G>IIC0w4(#fSy9pbtC(U4@KuIAP-r$z5%*kWLNQ{=#>Dv5T+Lz z1_9n|Y&uXya?HAL<)Fk20x)_zgv&4q>`9P!*+6-N!4s+glnDa7S=m5h%s>cI4^ayM D(unp3 literal 0 HcmV?d00001 diff --git a/test/expect/ModelTester.test_efficientnet_b5_expect.pkl b/test/expect/ModelTester.test_efficientnet_b5_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7c674259cd99d2663d7a85d1b74042e93bc40a1e GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK66d#r>56sJ&S*R+oxft`o>?nReQ#-K#dT;*IU=KeY{@2#7pI#ww64TzG)1bf z^pM}Dp>rTo!^2rr^H|I_js5X&HN4ELG<4DfS6GT4*9g)-t)%}h80Eu$Cvwv zwP|R~ch&qqalQsygNw$WmjW7mO*hrE-gK^vdT?#U-;c935{g?_$af`Z*ce|~5%k$r z-LCAE#w_*^E9zhAYGi2osO8;`(>!x>-txTXZki38Jj*|prmDYLb!Yj>y>m2nttwq{ z+;zLgtLNvJgF;Iuqqg(}Fr+{jcWCi5Si?iBEVZZ@7z=Jr=0pY;QpiD=##|tqFV0I3 zWdhm?!U5ilAPSx)k>gMRB!L3ZQz*J_WIyqt=zIm_A?wyRK-Y`xDt;8b5 zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=Y_3bG6o+A}hY|ny%RLJyzo%i<|~m{wj?#m%l8%aDD%ZY4XbID^Cfom=ir= z#j2_Q)DQGjFWcTCydp%~Tw~AIIQ6={tr{YV50}j|x}x!JN9Br)Fq0Kk-@DZ_FD_qR zah`w0AFaL>3_2fFzj4|wf04Lo>8fuR)Z64HtMja!w0wc?k!7u+5$ZApEz93eb6>u4 z?ph7Y-)S0OANDOTtYy@Y$kbVWHT;@-z=^`;NzZ%K`3@ghv7#?iL+^R^a*vgB)fNcE zsa|$opwV#n#u89yHO8NHI{^$S5XK!^{0!Fc&?-wUDh9@ao0B<_!G#oZ5T-E~$mWal z(nFbmwt{egHzSCGr%B{E6aYz}0Q3}!t{d4;d?-3!0eQ%}^$pPVBD;zoMXv zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK66f2#>&qwI;$Lpnv`Ia40_zH`Wabr~)w5T$-CTQ5p>vS#q@6>8VfGesIDp9ygcE?UX46s|K;Jf*2^n4GcJ3PsHFa-`t~xB zIg;uTPBm&(3TM=K!UWXW&7Ll^^4zVyW&ieNr#a6oUo_Qkg=OWh EfficientNet: + """ + Constructs a EfficientNet B1 architecture from + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + inverted_residual_setting = _efficientnet_conf(width_mult=1.0, depth_mult=1.1, **kwargs) + return _efficientnet_model("efficientnet_b1", inverted_residual_setting, 0.2, pretrained, progress, **kwargs) + + +def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + """ + Constructs a EfficientNet B2 architecture from + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + inverted_residual_setting = _efficientnet_conf(width_mult=1.1, depth_mult=1.2, **kwargs) + return _efficientnet_model("efficientnet_b2", inverted_residual_setting, 0.3, pretrained, progress, **kwargs) + + def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: """ Constructs a EfficientNet B3 architecture from @@ -278,3 +312,55 @@ def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: A """ inverted_residual_setting = _efficientnet_conf(width_mult=1.2, depth_mult=1.4, **kwargs) return _efficientnet_model("efficientnet_b3", inverted_residual_setting, 0.3, pretrained, progress, **kwargs) + + +def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + """ + Constructs a EfficientNet B4 architecture from + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + inverted_residual_setting = _efficientnet_conf(width_mult=1.4, depth_mult=1.8, **kwargs) + return _efficientnet_model("efficientnet_b4", inverted_residual_setting, 0.4, pretrained, progress, **kwargs) + + +def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + """ + Constructs a EfficientNet B5 architecture from + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs) + return _efficientnet_model("efficientnet_b5", inverted_residual_setting, 0.4, pretrained, progress, **kwargs) + + +def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + """ + Constructs a EfficientNet B6 architecture from + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs) + return _efficientnet_model("efficientnet_b6", inverted_residual_setting, 0.5, pretrained, progress, **kwargs) + + +def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + """ + Constructs a EfficientNet B7 architecture from + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs) + return _efficientnet_model("efficientnet_b7", inverted_residual_setting, 0.5, pretrained, progress, **kwargs) From d2bfd639e46e1c5dc3c177f889dc7750c8d137c7 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 Aug 2021 11:47:50 +0100 Subject: [PATCH 15/19] Update weights. --- references/classification/train.py | 6 +- ...odelTester.test_efficientnet_b0_expect.pkl | Bin 939 -> 939 bytes ...odelTester.test_efficientnet_b1_expect.pkl | Bin 939 -> 939 bytes ...odelTester.test_efficientnet_b2_expect.pkl | Bin 939 -> 939 bytes ...odelTester.test_efficientnet_b3_expect.pkl | Bin 939 -> 939 bytes ...odelTester.test_efficientnet_b4_expect.pkl | Bin 939 -> 939 bytes ...odelTester.test_efficientnet_b5_expect.pkl | Bin 939 -> 0 bytes ...odelTester.test_efficientnet_b6_expect.pkl | Bin 939 -> 0 bytes ...odelTester.test_efficientnet_b7_expect.pkl | Bin 939 -> 0 bytes torchvision/models/efficientnet.py | 58 +++--------------- 10 files changed, 11 insertions(+), 53 deletions(-) delete mode 100644 test/expect/ModelTester.test_efficientnet_b5_expect.pkl delete mode 100644 test/expect/ModelTester.test_efficientnet_b6_expect.pkl delete mode 100644 test/expect/ModelTester.test_efficientnet_b7_expect.pkl diff --git a/references/classification/train.py b/references/classification/train.py index a7dcce21411..750dd177252 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -89,11 +89,11 @@ def load_data(traindir, valdir, args): resize_size, crop_size = 342, 299 elif args.model.startswith('efficientnet_'): sizes = { - 'b0': 224, 'b1': 240, 'b2': 260, 'b3': 300, - 'b4': 380, 'b5': 456, 'b6': 528, 'b7': 600, + 'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300), + 'b4': (384, 380), 'b5': (489, 456), 'b6': (561, 528), 'b7': (633, 600), } e_type = args.model.replace('efficientnet_', '') - resize_size = crop_size = sizes[e_type] + resize_size, crop_size = sizes[e_type] interpolation = InterpolationMode.BICUBIC print("Loading training data") diff --git a/test/expect/ModelTester.test_efficientnet_b0_expect.pkl b/test/expect/ModelTester.test_efficientnet_b0_expect.pkl index c299eac15af0db2a6f55347c57e82fd168bc47e9..1de871ce0fbea9ddbab7e315b05f864bc5f6fa53 100644 GIT binary patch delta 230 zcmV`f{_4yvc`Y`Q4`Qd6q3>)c!!QD*YE%;hw;{_ExaeH*ySmzbUrMqzRPv0+ff0j gL>IECEbHv3P)i30F~yPIlMn*X1Tn>t-jn14#}|)o3jhEB delta 230 zcmVlF`cSG0l-?*BR1zwNZ{{h?n-r-40Y@rQ ztY@lIyLPHD=bI>)@kT0&9SSOdEmEezS&k`kwu31+yyK}|^(3j#&=IIC$o(mt`_ZV-m{F+p=q;*#VJ%TAp!~(B@fg~vqt|t+ zhXBT@9`3U#@oe=d)5)qSd|FYeE&q%ugcJWM#Ff;kfb1ozjRh#G##<_?4jXf-|UrgGx;eHYJ{q=cuXqFE{!OrZYU~1 zdUh!S4v?vx$-b$%-NPwxn;EDAM2;z9E#)bs^h+sajJv3lt1wZjsO#z}DPs01&^OX4 zr7}$_4t+$b78rdgev_rB7AM{)QAcK~=ujalGUtsd)zG4;F#Ad=szj2i`3Pnz&|+yS gmoh9W8Nz(3P)i30FFa$plMn*X1TQ>exs&7q$3yaDLI3~& delta 230 zcmVM}N`$J8`8_Hh z2Lh>jD;BCE3zR830{AEnS)3^VDj%hlk~1o})Lf_*YW69(`huzzK}sq-vx_JTnBUspn^GGuxwSka@Xy6)R37|CO*KaU?O09}kKsR^N}kj_XcU9OO-I)-B^n&M_E g8Ad89%Kmw(P)i30Sd-$wlMn*X1Xz>ez?0+x$7)t*XaE2J diff --git a/test/expect/ModelTester.test_efficientnet_b2_expect.pkl b/test/expect/ModelTester.test_efficientnet_b2_expect.pkl index 4cc5ed068bcaf8fefecae2889422fcbb17920259..f0aeb8ec122d5a350052ad5788918607cfb0cf91 100644 GIT binary patch delta 230 zcmVYCy0(h$6*r6vz;BzSut8Xb0 zy6~yM-##i~WzMO}fhek@2?i=2dFZH$-A$=PgEA_El!2+#a>%JHlbNWD*!w8<3_B|R z1*)l@AqFbJ$Kxr8{3|M;Qzxe#u)C<-aKR|P()y_w=_4w50}WBC*e{5w5gv`HxUT4^ zK$X2I>V?~>R1HKZ&<6LZqKg+Q46TnS4Sx)(m&uB$0c{8>L}Z64YnA_~SBD%bJf;FF gSCxvWE$$d9P)i30a_5~!lMn*X1ajw{Mw8?M$C%M+fdBvi delta 230 zcmVMaHRw%$KO>Ao(bJ134-r zEvTu&tOP22Q{gGG94ab!>m{c;g}JEcD!(Wr75J$TA|fj5cMDOfU$%y+tdonWBt+(@ zo`t(9{w3L|knKSzI8XGcy5kipa1@RxK=%r&UQCIpKJ*4Eee8uP(#rg(aL^hmr#S&C g;zNn3x;_^wP)i30%f__UlMn*X1k1*>)|2D{$04q1W&i*H diff --git a/test/expect/ModelTester.test_efficientnet_b3_expect.pkl b/test/expect/ModelTester.test_efficientnet_b3_expect.pkl index a93d3a5612729867bacdc1b65a183f6ee8dcb654..989d6782fe799c4833239c51f08e8375a6592179 100644 GIT binary patch delta 230 zcmV{oLTLmfP%U-DVNW`fu<1;Dob)>0y1{|siXeX+G z@1Lj@43#PPN0Og9&1$S`B5_J!^#cx@V~ud~3ZswD*}(jvI1Fhjqn^CVU%;Q#X}N5j`C+k}Uyz)Y#8 ga`w0=miVKoP)i306e?e_lMn*X1QaS?v6JKi$F9_DE&u=k delta 230 zcmVH{e!ZCt2JF~X^f?=dMetD>osRT`=dmnEtK zDxRorZj&iuv5=|!zEmkvC{L)=35uz}jNm8Iuf(aK0;(lG|Bop!Y8WZ{bsnj8uuG}V zU|6YO&kHA{_n4{qi&H7EhYqRS|468;9Hgn(;!3EVwTLK~bHq`n7N*ImlDc>)z(ND2 zmTZNojHO_xtn%wA{RP)i30&ZDhHlMn*X1kR(aMw8?M$KkebWdHyG diff --git a/test/expect/ModelTester.test_efficientnet_b4_expect.pkl b/test/expect/ModelTester.test_efficientnet_b4_expect.pkl index 3ac26865c50349717e73ff5f1befba55a7e23ced..f4a0cc04bf0ff3eee9f110f43a4c39a530c51f5e 100644 GIT binary patch delta 230 zcmV;pb#N#8tBrq*<* gMiq&t=4U7U zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK66d#r>56sJ&S*R+oxft`o>?nReQ#-K#dT;*IU=KeY{@2#7pI#ww64TzG)1bf z^pM}Dp>rTo!^2rr^H|I_js5X&HN4ELG<4DfS6GT4*9g)-t)%}h80Eu$Cvwv zwP|R~ch&qqalQsygNw$WmjW7mO*hrE-gK^vdT?#U-;c935{g?_$af`Z*ce|~5%k$r z-LCAE#w_*^E9zhAYGi2osO8;`(>!x>-txTXZki38Jj*|prmDYLb!Yj>y>m2nttwq{ z+;zLgtLNvJgF;Iuqqg(}Fr+{jcWCi5Si?iBEVZZ@7z=Jr=0pY;QpiD=##|tqFV0I3 zWdhm?!U5ilAPSx)k>gMRB!L3ZQz*J_WIyqt=zIm_A?wyRK-Y`xDt;8b5 zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=Y_3bG6o+A}hY|ny%RLJyzo%i<|~m{wj?#m%l8%aDD%ZY4XbID^Cfom=ir= z#j2_Q)DQGjFWcTCydp%~Tw~AIIQ6={tr{YV50}j|x}x!JN9Br)Fq0Kk-@DZ_FD_qR zah`w0AFaL>3_2fFzj4|wf04Lo>8fuR)Z64HtMja!w0wc?k!7u+5$ZApEz93eb6>u4 z?ph7Y-)S0OANDOTtYy@Y$kbVWHT;@-z=^`;NzZ%K`3@ghv7#?iL+^R^a*vgB)fNcE zsa|$opwV#n#u89yHO8NHI{^$S5XK!^{0!Fc&?-wUDh9@ao0B<_!G#oZ5T-E~$mWal z(nFbmwt{egHzSCGr%B{E6aYz}0Q3}!t{d4;d?-3!0eQ%}^$pPVBD;zoMXv zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK66f2#>&qwI;$Lpnv`Ia40_zH`Wabr~)w5T$-CTQ5p>vS#q@6>8VfGesIDp9ygcE?UX46s|K;Jf*2^n4GcJ3PsHFa-`t~xB zIg;uTPBm&(3TM=K!UWXW&7Ll^^4zVyW&ieNr#a6oUo_Qkg=OWh EfficientNet: - """ - Constructs a EfficientNet B5 architecture from - `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs) - return _efficientnet_model("efficientnet_b5", inverted_residual_setting, 0.4, pretrained, progress, **kwargs) - - -def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: - """ - Constructs a EfficientNet B6 architecture from - `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs) - return _efficientnet_model("efficientnet_b6", inverted_residual_setting, 0.5, pretrained, progress, **kwargs) - - -def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: - """ - Constructs a EfficientNet B7 architecture from - `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs) - return _efficientnet_model("efficientnet_b7", inverted_residual_setting, 0.5, pretrained, progress, **kwargs) From 8330faba19fa249120b315e0a8faad227fc00215 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 Aug 2021 13:16:45 +0100 Subject: [PATCH 16/19] Adding B5-B7 weights. --- references/classification/train.py | 2 +- ...odelTester.test_efficientnet_b5_expect.pkl | Bin 0 -> 939 bytes ...odelTester.test_efficientnet_b6_expect.pkl | Bin 0 -> 939 bytes ...odelTester.test_efficientnet_b7_expect.pkl | Bin 0 -> 939 bytes torchvision/models/efficientnet.py | 48 +++++++++++++++++- 5 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 test/expect/ModelTester.test_efficientnet_b5_expect.pkl create mode 100644 test/expect/ModelTester.test_efficientnet_b6_expect.pkl create mode 100644 test/expect/ModelTester.test_efficientnet_b7_expect.pkl diff --git a/references/classification/train.py b/references/classification/train.py index 750dd177252..9ba99b3dc54 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -90,7 +90,7 @@ def load_data(traindir, valdir, args): elif args.model.startswith('efficientnet_'): sizes = { 'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300), - 'b4': (384, 380), 'b5': (489, 456), 'b6': (561, 528), 'b7': (633, 600), + 'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600), } e_type = args.model.replace('efficientnet_', '') resize_size, crop_size = sizes[e_type] diff --git a/test/expect/ModelTester.test_efficientnet_b5_expect.pkl b/test/expect/ModelTester.test_efficientnet_b5_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..7c674259cd99d2663d7a85d1b74042e93bc40a1e GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK66d#r>56sJ&S*R+oxft`o>?nReQ#-K#dT;*IU=KeY{@2#7pI#ww64TzG)1bf z^pM}Dp>rTo!^2rr^H|I_js5X&HN4ELG<4DfS6GT4*9g)-t)%}h80Eu$Cvwv zwP|R~ch&qqalQsygNw$WmjW7mO*hrE-gK^vdT?#U-;c935{g?_$af`Z*ce|~5%k$r z-LCAE#w_*^E9zhAYGi2osO8;`(>!x>-txTXZki38Jj*|prmDYLb!Yj>y>m2nttwq{ z+;zLgtLNvJgF;Iuqqg(}Fr+{jcWCi5Si?iBEVZZ@7z=Jr=0pY;QpiD=##|tqFV0I3 zWdhm?!U5ilAPSx)k>gMRB!L3ZQz*J_WIyqt=zIm_A?wyRK-Y`xDt;8b5 zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5=Y_3bG6o+A}hY|ny%RLJyzo%i<|~m{wj?#m%l8%aDD%ZY4XbID^Cfom=ir= z#j2_Q)DQGjFWcTCydp%~Tw~AIIQ6={tr{YV50}j|x}x!JN9Br)Fq0Kk-@DZ_FD_qR zah`w0AFaL>3_2fFzj4|wf04Lo>8fuR)Z64HtMja!w0wc?k!7u+5$ZApEz93eb6>u4 z?ph7Y-)S0OANDOTtYy@Y$kbVWHT;@-z=^`;NzZ%K`3@ghv7#?iL+^R^a*vgB)fNcE zsa|$opwV#n#u89yHO8NHI{^$S5XK!^{0!Fc&?-wUDh9@ao0B<_!G#oZ5T-E~$mWal z(nFbmwt{egHzSCGr%B{E6aYz}0Q3}!t{d4;d?-3!0eQ%}^$pPVBD;zoMXv zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK66f2#>&qwI;$Lpnv`Ia40_zH`Wabr~)w5T$-CTQ5p>vS#q@6>8VfGesIDp9ygcE?UX46s|K;Jf*2^n4GcJ3PsHFa-`t~xB zIg;uTPBm&(3TM=K!UWXW&7Ll^^4zVyW&ieNr#a6oUo_Qkg=OWh EfficientNet: + """ + Constructs a EfficientNet B5 architecture from + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs) + return _efficientnet_model("efficientnet_b5", inverted_residual_setting, 0.4, pretrained, progress, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs) + + +def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + """ + Constructs a EfficientNet B6 architecture from + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs) + return _efficientnet_model("efficientnet_b6", inverted_residual_setting, 0.5, pretrained, progress, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs) + + +def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: + """ + Constructs a EfficientNet B7 architecture from + `"EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs) + return _efficientnet_model("efficientnet_b7", inverted_residual_setting, 0.5, pretrained, progress, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs) From 901b282c92a71b1e3b1746188337c7d7434ebea8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 Aug 2021 14:22:36 +0100 Subject: [PATCH 17/19] Update docs and hubconf. --- docs/source/models.rst | 47 +++++++++++++++++++++++++++-- hubconf.py | 2 ++ references/classification/README.md | 6 ++++ torchvision/models/efficientnet.py | 17 +++++------ 4 files changed, 60 insertions(+), 12 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index b9bff7a36e8..a9d539c3a73 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -19,14 +19,15 @@ architectures for image classification: - `ResNet`_ - `SqueezeNet`_ - `DenseNet`_ -- `Inception`_ v3 +- `InceptionV3`_ - `GoogLeNet`_ -- `ShuffleNet`_ v2 +- `ShuffleNetV2`_ - `MobileNetV2`_ - `MobileNetV3`_ - `ResNeXt`_ - `Wide ResNet`_ - `MNASNet`_ +- `EfficientNet`_ You can construct a model with random weights by calling its constructor: @@ -47,6 +48,14 @@ You can construct a model with random weights by calling its constructor: resnext50_32x4d = models.resnext50_32x4d() wide_resnet50_2 = models.wide_resnet50_2() mnasnet = models.mnasnet1_0() + efficientnet_b0 = models.efficientnet_b0() + efficientnet_b1 = models.efficientnet_b1() + efficientnet_b2 = models.efficientnet_b2() + efficientnet_b3 = models.efficientnet_b3() + efficientnet_b4 = models.efficientnet_b4() + efficientnet_b5 = models.efficientnet_b5() + efficientnet_b6 = models.efficientnet_b6() + efficientnet_b7 = models.efficientnet_b7() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing ``pretrained=True``: @@ -68,6 +77,14 @@ These can be constructed by passing ``pretrained=True``: resnext50_32x4d = models.resnext50_32x4d(pretrained=True) wide_resnet50_2 = models.wide_resnet50_2(pretrained=True) mnasnet = models.mnasnet1_0(pretrained=True) + efficientnet_b0 = models.efficientnet_b0(pretrained=True) + efficientnet_b1 = models.efficientnet_b1(pretrained=True) + efficientnet_b2 = models.efficientnet_b2(pretrained=True) + efficientnet_b3 = models.efficientnet_b3(pretrained=True) + efficientnet_b4 = models.efficientnet_b4(pretrained=True) + efficientnet_b5 = models.efficientnet_b5(pretrained=True) + efficientnet_b6 = models.efficientnet_b6(pretrained=True) + efficientnet_b7 = models.efficientnet_b7(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the `TORCH_MODEL_ZOO` environment variable. See @@ -113,7 +130,10 @@ Unfortunately, the concrete `subset` that was used is lost. For more information see `this discussion `_ or `these experiments `_. -ImageNet 1-crop error rates (224x224) +The sizes of the EfficientNet models depends on the variant. For the exact configuration check +`here `_ + +ImageNet 1-crop error rates ================================ ============= ============= Model Acc@1 Acc@5 @@ -151,6 +171,14 @@ Wide ResNet-50-2 78.468 94.086 Wide ResNet-101-2 78.848 94.284 MNASNet 1.0 73.456 91.510 MNASNet 0.5 67.734 87.490 +EfficientNet-B0 77.692 93.532 +EfficientNet-B1 78.642 94.186 +EfficientNet-B2 80.608 95.310 +EfficientNet-B3 82.008 96.054 +EfficientNet-B4 83.384 96.594 +EfficientNet-B5 83.444 96.628 +EfficientNet-B6 84.008 96.916 +EfficientNet-B7 84.122 96.908 ================================ ============= ============= @@ -166,6 +194,7 @@ MNASNet 0.5 67.734 87.490 .. _MobileNetV3: https://arxiv.org/abs/1905.02244 .. _ResNeXt: https://arxiv.org/abs/1611.05431 .. _MNASNet: https://arxiv.org/abs/1807.11626 +.. _EfficientNet: https://arxiv.org/abs/1905.11946 .. currentmodule:: torchvision.models @@ -267,6 +296,18 @@ MNASNet .. autofunction:: mnasnet1_0 .. autofunction:: mnasnet1_3 +EfficientNet +------------ + +.. autofunction:: efficientnet_b0 +.. autofunction:: efficientnet_b1 +.. autofunction:: efficientnet_b2 +.. autofunction:: efficientnet_b3 +.. autofunction:: efficientnet_b4 +.. autofunction:: efficientnet_b5 +.. autofunction:: efficientnet_b6 +.. autofunction:: efficientnet_b7 + Quantized Models ---------------- diff --git a/hubconf.py b/hubconf.py index 097759bdd89..2bff6850525 100644 --- a/hubconf.py +++ b/hubconf.py @@ -15,6 +15,8 @@ from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \ mnasnet1_3 +from torchvision.models.efficientnet import efficientnet_b0, efficientnet_b1, efficientnet_b2, \ + efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7 # segmentation from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \ diff --git a/references/classification/README.md b/references/classification/README.md index e0b7f210175..210a63c0bca 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -68,6 +68,12 @@ Then we averaged the parameters of the last 3 checkpoints that improved the Acc@ and [#3354](https://github.com/pytorch/vision/pull/3354) for details. +### EfficientNet + +The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](https://github.com/rwightman/pytorch-image-models/blob/01cb46a9a50e3ba4be167965b5764e9702f09b30/timm/models/efficientnet.py#L95-L108). + +The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564). + ## Mixed precision training Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [NVIDIA Apex extension](https://github.com/NVIDIA/apex). diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 14b8d4e09e0..dd3550569e8 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -10,7 +10,6 @@ from .._internally_replaced_utils import load_state_dict_from_url from torchvision.ops import StochasticDepth -# TODO: refactor this to a common place? from torchvision.models.mobilenetv2 import ConvBNActivation, _make_divisible @@ -20,15 +19,15 @@ model_urls = { # Weights ported from https://github.com/rwightman/pytorch-image-models/ - "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0-rwightman.pth", - "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1-rwightman.pth", - "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2-rwightman.pth", - "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3-rwightman.pth", - "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4-rwightman.pth", + "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", + "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", + "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", + "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", + "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/ - "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5-lukemelas.pth", - "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6-lukemelas.pth", - "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7-lukemelas.pth", + "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", + "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", + "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", } From 7f8dae35206b659fa3e34eebb1159f3526f3e2ac Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 25 Aug 2021 16:42:38 +0100 Subject: [PATCH 18/19] Fix doc link. --- docs/source/models.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index a9d539c3a73..64ca69f47ae 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -19,9 +19,9 @@ architectures for image classification: - `ResNet`_ - `SqueezeNet`_ - `DenseNet`_ -- `InceptionV3`_ +- `Inception`_ v3 - `GoogLeNet`_ -- `ShuffleNetV2`_ +- `ShuffleNet`_ v2 - `MobileNetV2`_ - `MobileNetV3`_ - `ResNeXt`_ @@ -130,8 +130,8 @@ Unfortunately, the concrete `subset` that was used is lost. For more information see `this discussion `_ or `these experiments `_. -The sizes of the EfficientNet models depends on the variant. For the exact configuration check -`here `_ +The sizes of the EfficientNet models depend on the variant. For the exact input sizes +`check here `_ ImageNet 1-crop error rates From 210b3e2918e07da3f5e67e0c309a1d5fcae53929 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 26 Aug 2021 11:03:00 +0100 Subject: [PATCH 19/19] Fix typo on comment. --- torchvision/models/efficientnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index dd3550569e8..06b2a301b6d 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -50,7 +50,7 @@ def forward(self, input: Tensor) -> Tensor: class MBConvConfig: - # Stores information listed at Tables 1 of the EfficientNet paper + # Stores information listed at Table 1 of the EfficientNet paper def __init__(self, expand_ratio: float, kernel: int, stride: int, input_channels: int, out_channels: int, num_layers: int,