From a24b7c10d8b7e3e88a510b09d827d1cb2e6a2c32 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 12 May 2021 16:11:20 +0100 Subject: [PATCH] Cherrypicking cleanups for SSD and SSDlite. --- docs/source/models.rst | 14 +++++------ references/detection/README.md | 4 ++-- torchvision/models/detection/ssd.py | 31 ++++++++++++++----------- torchvision/models/detection/ssdlite.py | 30 ++++++++++++------------ 4 files changed, 41 insertions(+), 38 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index f9fb793ed36..c2b81e49735 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -426,8 +426,8 @@ Faster R-CNN ResNet-50 FPN 37.0 - - Faster R-CNN MobileNetV3-Large FPN 32.8 - - Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - - RetinaNet ResNet-50 FPN 36.4 - - -SSD VGG16 25.1 - - -SSDlite MobileNetV3-Large 21.3 - - +SSD300 VGG16 25.1 - - +SSDlite320 MobileNetV3-Large 21.3 - - Mask R-CNN ResNet-50 FPN 37.9 34.6 - ====================================== ======= ======== =========== @@ -486,8 +486,8 @@ Faster R-CNN ResNet-50 FPN 0.2288 0.0590 Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0 Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6 RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1 -SSD VGG16 0.2093 0.0744 1.5 -SSDlite MobileNetV3-Large 0.1773 0.0906 1.5 +SSD300 VGG16 0.2093 0.0744 1.5 +SSDlite320 MobileNetV3-Large 0.1773 0.0906 1.5 Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4 Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8 ====================================== =================== ================== =========== @@ -502,19 +502,19 @@ Faster R-CNN RetinaNet ------------- +--------- .. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn SSD ------------- +--- .. autofunction:: torchvision.models.detection.ssd300_vgg16 SSDlite ------------- +------- .. 
autofunction:: torchvision.models.detection.ssdlite320_mobilenet_v3_large diff --git a/references/detection/README.md b/references/detection/README.md index 2fb0b658aa7..ea5be6ea791 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -48,7 +48,7 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 ``` -### SSD VGG16 +### SSD300 VGG16 ``` python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ --dataset coco --model ssd300_vgg16 --epochs 120\ @@ -56,7 +56,7 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ --weight-decay 0.0005 --data-augmentation ssd ``` -### SSDlite MobileNetV3-Large +### SSDlite320 MobileNetV3-Large ``` python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ --dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\ diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index aeb93012d78..f6150cf5cd5 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -410,7 +410,7 @@ def postprocess_detections(self, head_outputs: Dict[str, Tensor], image_anchors: class SSDFeatureExtractorVGG(nn.Module): - def __init__(self, backbone: nn.Module, highres: bool, rescaling: bool): + def __init__(self, backbone: nn.Module, highres: bool): super().__init__() _, _, maxpool3_pos, maxpool4_pos, _ = (i for i, layer in enumerate(backbone) if isinstance(layer, nn.MaxPool2d)) @@ -476,13 +476,8 @@ def __init__(self, backbone: nn.Module, highres: bool, rescaling: bool): fc, )) self.extra = extra - self.rescaling = rescaling def forward(self, x: Tensor) -> Dict[str, Tensor]: - # Undo the 0-1 scaling of toTensor. Necessary for some backbones. 
- if self.rescaling: - x *= 255 - # L2 regularization + Rescaling of 1st block's feature map x = self.features(x) rescaled = self.scale_weight.view(1, -1, 1, 1) * F.normalize(x) @@ -496,8 +491,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int, - rescaling: bool): +def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int): if backbone_name in backbone_urls: # Use custom backbones more appropriate for SSD arch = backbone_name.split('_')[0] @@ -521,19 +515,19 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained for parameter in b.parameters(): parameter.requires_grad_(False) - return SSDFeatureExtractorVGG(backbone, highres, rescaling) + return SSDFeatureExtractorVGG(backbone, highres) def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int = 91, pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any): """ - Constructs an SSD model with a VGG16 backbone. See `SSD` for more details. + Constructs an SSD model with input size 300x300 and a VGG16 backbone. See `SSD` for more details. Example: >>> model = torchvision.models.detection.ssd300_vgg16(pretrained=True) >>> model.eval() - >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)] >>> predictions = model(x) Args: @@ -544,6 +538,9 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. 
""" + if "size" in kwargs: + warnings.warn("The size of the model is already fixed; ignoring the argument.") + trainable_backbone_layers = _validate_trainable_layers( pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5) @@ -551,12 +548,18 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers, True) + backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers) anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]], scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], steps=[8, 16, 32, 64, 100, 300]) - model = SSD(backbone, anchor_generator, (300, 300), num_classes, - image_mean=[0.48235, 0.45882, 0.40784], image_std=[1., 1., 1.], **kwargs) + + defaults = { + # Rescale the input in a way compatible to the backbone + "image_mean": [0.48235, 0.45882, 0.40784], + "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of toTensor + } + kwargs = {**defaults, **kwargs} + model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) if pretrained: weights_name = 'ssd300_vgg16_coco' if model_urls.get(weights_name, None) is None: diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 412434dabd7..8498a78d6dd 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -1,4 +1,5 @@ import torch +import warnings from collections import OrderedDict from functools import partial @@ -94,8 +95,7 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: C class SSDLiteFeatureExtractorMobileNet(nn.Module): - def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], rescaling: bool, - **kwargs: Any): + 
def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], **kwargs: Any): super().__init__() # non-public config parameters min_depth = kwargs.pop('_min_depth', 16) @@ -117,13 +117,8 @@ def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., n _normal_init(extra) self.extra = extra - self.rescaling = rescaling def forward(self, x: Tensor) -> Dict[str, Tensor]: - # Rescale from [0, 1] to [-1, -1] - if self.rescaling: - x = 2.0 * x - 1.0 - # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations. output = [] for block in self.features: @@ -138,7 +133,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int, - norm_layer: Callable[..., nn.Module], rescaling: bool, **kwargs: Any): + norm_layer: Callable[..., nn.Module], **kwargs: Any): backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs).features if not pretrained: @@ -158,7 +153,7 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t for parameter in b.parameters(): parameter.requires_grad_(False) - return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, rescaling, **kwargs) + return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs) def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91, @@ -166,7 +161,7 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru norm_layer: Optional[Callable[..., nn.Module]] = None, **kwargs: Any): """ - Constructs an SSDlite model with a MobileNetV3 Large backbone. See `SSD` for more details. + Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone. See `SSD` for more details. 
Example: @@ -186,20 +181,23 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. norm_layer (callable, optional): Module specifying the normalization layer to use. """ + if "size" in kwargs: + warnings.warn("The size of the model is already fixed; ignoring the argument.") + trainable_backbone_layers = _validate_trainable_layers( pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6) if pretrained: pretrained_backbone = False - # Enable [-1, 1] rescaling and reduced tail if no pretrained backbone is selected - rescaling = reduce_tail = not pretrained_backbone + # Enable reduced tail if no pretrained backbone is selected + reduce_tail = not pretrained_backbone if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers, - norm_layer, rescaling, _reduced_tail=reduce_tail, _width_mult=1.0) + norm_layer, _reduced_tail=reduce_tail, _width_mult=1.0) size = (320, 320) anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95) @@ -212,8 +210,10 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru "nms_thresh": 0.55, "detections_per_img": 300, "topk_candidates": 300, - "image_mean": [0., 0., 0.], - "image_std": [1., 1., 1.], + # Rescale the input in a way compatible to the backbone: + # The following mean/std rescale the data from [0, 1] to [-1, 1] + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5], } kwargs = {**defaults, **kwargs} model = SSD(backbone, anchor_generator, size, num_classes,