diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 4dde47d40c3..06baea35fa8 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -74,7 +74,7 @@ def test_get_weight(name, weight): @pytest.mark.parametrize( "model_fn", TM.get_models_from_module(torchvision.models) - + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(torchvision.models.detection) + TM.get_models_from_module(torchvision.models.quantization) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) @@ -90,7 +90,7 @@ def test_naming_conventions(model_fn): @pytest.mark.parametrize( "model_fn", TM.get_models_from_module(torchvision.models) - + TM.get_models_from_module(models.detection) + + TM.get_models_from_module(torchvision.models.detection) + TM.get_models_from_module(torchvision.models.quantization) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) @@ -143,13 +143,6 @@ def test_schema_meta_validation(model_fn): assert not bad_names -@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection)) -@pytest.mark.parametrize("dev", cpu_and_gpu()) -@run_if_test_with_prototype -def test_detection_model(model_fn, dev): - TM.test_detection_model(model_fn, dev) - - @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation)) @pytest.mark.parametrize("dev", cpu_and_gpu()) @run_if_test_with_prototype @@ -174,8 +167,7 @@ def test_raft(model_builder, scripted): @pytest.mark.parametrize( "model_fn", - TM.get_models_from_module(models.detection) - + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.segmentation) + TM.get_models_from_module(models.video) + TM.get_models_from_module(models.optical_flow), ) diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index be46f950a61..4146651c737 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -1,7 +1,7 @@ from .faster_rcnn import * -from .mask_rcnn import * +from .fcos import * from .keypoint_rcnn import * +from .mask_rcnn import * from .retinanet import * from .ssd import * from .ssdlite import * -from .fcos import * diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 5ac5f179479..cac96b61f64 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -88,7 +88,7 @@ def resnet_fpn_backbone( pretrained (bool): If True, returns a model with backbone pre-trained on Imagenet norm_layer (callable): it is recommended to use the default value. For details visit: (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267) - trainable_layers (int): number of trainable (not frozen) resnet layers starting from final block. + trainable_layers (int): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``. By default all layers are returned. 
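For orientation before the per-file hunks that follow: once promoted, these builders are driven by weights enums instead of `pretrained=` booleans. A minimal usage sketch, using only names the hunks below define (`FasterRCNN_ResNet50_FPN_Weights`, its `DEFAULT` alias, the bundled `ObjectDetectionEval` transforms, and the `meta` dict); calling `weights.transforms()` to instantiate the preset is an assumption based on the multi-weight API, not something shown verbatim in this diff:

```python
import torch
from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_Weights,
    fasterrcnn_resnet50_fpn,
)

# Pick a weights entry; DEFAULT aliases COCO_V1 in the enum defined below.
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn(weights=weights).eval()

# Each entry bundles its inference preset and metadata.
preprocess = weights.transforms()  # assumed: instantiates ObjectDetectionEval
batch = [preprocess(torch.rand(3, 320, 320))]
predictions = model(batch)  # list of dicts with "boxes", "labels", "scores"

print(weights.meta["map"], len(weights.meta["categories"]))  # 37.0, 91
```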
diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 790740fe9c5..18872adc029 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -1,11 +1,16 @@ +from typing import Any, Optional, Union + import torch.nn.functional as F from torch import nn from torchvision.ops import MultiScaleRoIAlign -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import misc as misc_nn_ops -from ..mobilenetv3 import mobilenet_v3_large -from ..resnet import resnet50 +from ...transforms import ObjectDetectionEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large +from ..resnet import ResNet50_Weights, resnet50 from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor @@ -17,9 +22,12 @@ __all__ = [ "FasterRCNN", + "FasterRCNN_ResNet50_FPN_Weights", + "FasterRCNN_MobileNet_V3_Large_FPN_Weights", + "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights", "fasterrcnn_resnet50_fpn", - "fasterrcnn_mobilenet_v3_large_320_fpn", "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", ] @@ -307,16 +315,70 @@ def forward(self, x): return scores, bbox_deltas -model_urls = { - "fasterrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", - "fasterrcnn_mobilenet_v3_large_320_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", - "fasterrcnn_mobilenet_v3_large_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", +_COMMON_META = { + "task": "image_object_detection", + "architecture": "FasterRCNN", + "publication_year": 2015, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, } +class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", + transforms=ObjectDetectionEval, + meta={ + **_COMMON_META, + "num_params": 41755286, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", + "map": 37.0, + }, + ) + DEFAULT = COCO_V1 + + +class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", + transforms=ObjectDetectionEval, + meta={ + **_COMMON_META, + "num_params": 19386354, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", + "map": 32.8, + }, + ) + DEFAULT = COCO_V1 + + +class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", + transforms=ObjectDetectionEval, + meta={ + **_COMMON_META, + "num_params": 19386354, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", + "map": 22.8, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def 
fasterrcnn_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: """ Constructs a Faster R-CNN model with a ResNet-50-FPN backbone. @@ -375,51 +437,60 @@ def fasterrcnn_resnet50_fpn( >>> torch.onnx.export(model, x, "faster_rcnn.onnx", opset_version = 11) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FasterRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. """ - is_trained = pretrained or pretrained_backbone + weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = FasterRCNN(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model def _fasterrcnn_mobilenet_v3_large_fpn( - weights_name, - pretrained=False, - progress=True, - num_classes=91, - pretrained_backbone=True, - trainable_backbone_layers=None, - **kwargs, -): - is_trained = pretrained or pretrained_backbone + *, + weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]], + progress: bool, + num_classes: Optional[int], + weights_backbone: Optional[MobileNet_V3_Large_Weights], + trainable_backbone_layers: 
Optional[int], + **kwargs: Any, +) -> FasterRCNN: + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - pretrained_backbone = False - - backbone = mobilenet_v3_large(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) - anchor_sizes = ( ( 32, @@ -430,21 +501,29 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ), ) * 3 aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - model = FasterRCNN( backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs ) - if pretrained: - if model_urls.get(weights_name, None) is None: - raise ValueError(f"No checkpoint is available for model {weights_name}") - state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: """ Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tuned for mobile use-cases. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See @@ -459,15 +538,17 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FasterRCNN_MobileNet_V3_Large_320_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
""" - weights_name = "fasterrcnn_mobilenet_v3_large_320_fpn_coco" + weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + defaults = { "min_size": 320, "max_size": 640, @@ -478,19 +559,28 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( kwargs = {**defaults, **kwargs} return _fasterrcnn_mobilenet_v3_large_fpn( - weights_name, - pretrained=pretrained, + weights=weights, progress=progress, num_classes=num_classes, - pretrained_backbone=pretrained_backbone, + weights_backbone=weights_backbone, trainable_backbone_layers=trainable_backbone_layers, **kwargs, ) +@handle_legacy_interface( + weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def fasterrcnn_mobilenet_v3_large_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: """ Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See @@ -505,26 +595,27 @@ def fasterrcnn_mobilenet_v3_large_fpn( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FasterRCNN_MobileNet_V3_Large_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
""" - weights_name = "fasterrcnn_mobilenet_v3_large_fpn_coco" + weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + defaults = { "rpn_score_thresh": 0.05, } kwargs = {**defaults, **kwargs} return _fasterrcnn_mobilenet_v3_large_fpn( - weights_name, - pretrained=pretrained, + weights=weights, progress=progress, num_classes=num_classes, - pretrained_backbone=pretrained_backbone, + weights_backbone=weights_backbone, trainable_backbone_layers=trainable_backbone_layers, **kwargs, ) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index c4c2e6f5842..8d110d809f7 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -2,25 +2,32 @@ import warnings from collections import OrderedDict from functools import partial -from typing import Callable, Dict, List, Tuple, Optional +from typing import Any, Callable, Dict, List, Tuple, Optional import torch from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import sigmoid_focal_loss, generalized_box_iou_loss from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...transforms import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once -from ..resnet import resnet50 +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from . import _utils as det_utils from .anchor_utils import AnchorGenerator from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .transform import GeneralizedRCNNTransform -__all__ = ["FCOS", "fcos_resnet50_fpn"] +__all__ = [ + "FCOS", + "FCOS_ResNet50_FPN_Weights", + "fcos_resnet50_fpn", +] class FCOSHead(nn.Module): @@ -626,19 +633,37 @@ def forward( return self.eager_outputs(losses, detections) -model_urls = { - "fcos_resnet50_fpn_coco": "https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", -} +class FCOS_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", + transforms=ObjectDetectionEval, + meta={ + "task": "image_object_detection", + "architecture": "FCOS", + "publication_year": 2019, + "num_params": 32269600, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn", + "map": 39.2, + }, + ) + DEFAULT = COCO_V1 +@handle_legacy_interface( + weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def fcos_resnet50_fpn( - pretrained: bool = False, + *, + weights: Optional[FCOS_ResNet50_FPN_Weights] = None, progress: bool = True, - num_classes: int = 91, - pretrained_backbone: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, trainable_backbone_layers: Optional[int] = None, - **kwargs, -): + **kwargs: Any, +) -> FCOS: """ Constructs a FCOS model with a ResNet-50-FPN backbone. 
@@ -678,28 +703,34 @@ def fcos_resnet50_fpn( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (FCOS_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. Default: None """ - is_trained = pretrained or pretrained_backbone + weights = FCOS_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor( backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) ) model = FCOS(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["fcos_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 9f23e66e0c5..3794b253ec7 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -1,16 +1,25 @@ +from typing import Any, Optional + import torch from torch import nn from torchvision.ops import MultiScaleRoIAlign -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import misc as misc_nn_ops -from ..resnet import resnet50 +from ...transforms import ObjectDetectionEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from ._utils import overwrite_eps from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .faster_rcnn import FasterRCNN -__all__ = ["KeypointRCNN", "keypointrcnn_resnet50_fpn"] +__all__ = [ + "KeypointRCNN", + "KeypointRCNN_ResNet50_FPN_Weights", + "keypointrcnn_resnet50_fpn", +] class KeypointRCNN(FasterRCNN): @@ -293,22 +302,61 @@ def forward(self, x): ) -model_urls = { - # legacy model for BC reasons, see 
https://github.com/pytorch/vision/issues/1606 - "keypointrcnn_resnet50_fpn_coco_legacy": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", - "keypointrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", +_COMMON_META = { + "task": "image_object_detection", + "architecture": "KeypointRCNN", + "publication_year": 2017, + "categories": _COCO_PERSON_CATEGORIES, + "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES, + "interpolation": InterpolationMode.BILINEAR, } +class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_LEGACY = Weights( + url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", + transforms=ObjectDetectionEval, + meta={ + **_COMMON_META, + "num_params": 59137258, + "recipe": "https://github.com/pytorch/vision/issues/1606", + "map": 50.6, + "map_kp": 61.1, + }, + ) + COCO_V1 = Weights( + url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", + transforms=ObjectDetectionEval, + meta={ + **_COMMON_META, + "num_params": 59137258, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn", + "map": 54.6, + "map_kp": 65.0, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=( + "pretrained", + lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY + if kwargs["pretrained"] == "legacy" + else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1, + ), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def keypointrcnn_resnet50_fpn( - pretrained=False, - progress=True, - num_classes=2, - num_keypoints=17, - pretrained_backbone=True, - trainable_backbone_layers=None, - **kwargs, -): + *, + weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + num_keypoints: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> KeypointRCNN: """ Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone. @@ -356,31 +404,39 @@ def keypointrcnn_resnet50_fpn( >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version = 11) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (KeypointRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - num_keypoints (int): number of keypoints, default 17 - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + num_keypoints (int, optional): number of keypoints + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
""" - is_trained = pretrained or pretrained_backbone + weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"])) + else: + if num_classes is None: + num_classes = 2 + if num_keypoints is None: + num_keypoints = 17 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) - if pretrained: - key = "keypointrcnn_resnet50_fpn_coco" - if pretrained == "legacy": - key += "_legacy" - state_dict = load_state_dict_from_url(model_urls[key], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 37f88116c5e..38ba82af01d 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -1,17 +1,23 @@ from collections import OrderedDict +from typing import Any, Optional from torch import nn from torchvision.ops import MultiScaleRoIAlign -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import misc as misc_nn_ops -from ..resnet import resnet50 +from ...transforms import ObjectDetectionEval, InterpolationMode +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from ._utils import overwrite_eps from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .faster_rcnn import FasterRCNN + __all__ = [ "MaskRCNN", + "MaskRCNN_ResNet50_FPN_Weights", "maskrcnn_resnet50_fpn", ] @@ -296,14 +302,38 @@ def __init__(self, in_channels, dim_reduced, num_classes): # nn.init.constant_(param, 0) -model_urls = { - "maskrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", -} - - +class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", + transforms=ObjectDetectionEval, + meta={ + "task": "image_object_detection", + "architecture": "MaskRCNN", + "publication_year": 2017, + "num_params": 44401393, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn", + "map": 37.9, + "map_mask": 34.6, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=("pretrained", 
MaskRCNN_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def maskrcnn_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> MaskRCNN: """ Constructs a Mask R-CNN model with a ResNet-50-FPN backbone. @@ -352,27 +382,34 @@ def maskrcnn_resnet50_fpn( >>> torch.onnx.export(model, x, "mask_rcnn.onnx", opset_version = 11) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (MaskRCNN_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. """ - is_trained = pretrained or pretrained_backbone + weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) - model = MaskRCNN(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + model = MaskRCNN(backbone, num_classes=num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 4f79b5ddbfc..b1c371583bf 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -1,18 +1,21 @@ import math import warnings from collections import OrderedDict -from typing import Dict, List, Tuple, Optional +from typing import Any, Dict, List, 
Tuple, Optional import torch from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import sigmoid_focal_loss from ...ops import boxes as box_ops from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...transforms import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once -from ..resnet import resnet50 +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 from . import _utils as det_utils from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator @@ -20,7 +23,11 @@ from .transform import GeneralizedRCNNTransform -__all__ = ["RetinaNet", "retinanet_resnet50_fpn"] +__all__ = [ + "RetinaNet", + "RetinaNet_ResNet50_FPN_Weights", + "retinanet_resnet50_fpn", +] def _sum(x: List[Tensor]) -> Tensor: @@ -571,14 +578,37 @@ def forward(self, images, targets=None): return self.eager_outputs(losses, detections) -model_urls = { - "retinanet_resnet50_fpn_coco": "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", -} +class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", + transforms=ObjectDetectionEval, + meta={ + "task": "image_object_detection", + "architecture": "RetinaNet", + "publication_year": 2017, + "num_params": 34014999, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", + "map": 36.4, + }, + ) + DEFAULT = COCO_V1 +@handle_legacy_interface( + weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1), +) def retinanet_resnet50_fpn( - pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs -): + *, + weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> RetinaNet: """ Constructs a RetinaNet model with a ResNet-50-FPN backbone. @@ -618,30 +648,37 @@ def retinanet_resnet50_fpn( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (RetinaNet_ResNet50_FPN_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (ResNet50_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 3. 
""" - is_trained = pretrained or pretrained_backbone + weights = RetinaNet_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + is_trained = weights is not None or weights_backbone is not None trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3) norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - - backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer) # skip P2 because it generates too many anchors (according to their paper) backbone = _resnet_fpn_extractor( backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) ) model = RetinaNet(backbone, num_classes, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls["retinanet_resnet50_fpn_coco"], progress=progress) - model.load_state_dict(state_dict) - overwrite_eps(model, 0.0) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1: + overwrite_eps(model, 0.0) + return model diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index b7bbb81111e..ab901449f51 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -4,8 +4,7 @@ import torch.nn.functional as F import torchvision from torch import nn, Tensor -from torchvision.ops import boxes as box_ops -from torchvision.ops import roi_align +from torchvision.ops import boxes as box_ops, roi_align from . import _utils as det_utils diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index 08a9ed68e4e..cf3becc5fc4 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -6,27 +6,42 @@ import torch.nn.functional as F from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import boxes as box_ops +from ...transforms import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once -from .. import vgg +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..vgg import VGG, VGG16_Weights, vgg16 from . import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers from .transform import GeneralizedRCNNTransform -__all__ = ["SSD", "ssd300_vgg16"] -model_urls = { - "ssd300_vgg16_coco": "https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", -} - -backbone_urls = { - # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the - # same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth - # Only the `features` weights have proper values, those on the `classifier` module are filled with nans. 
- "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth" -} +__all__ = [ + "SSD300_VGG16_Weights", + "ssd300_vgg16", +] + + +class SSD300_VGG16_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", + transforms=ObjectDetectionEval, + meta={ + "task": "image_object_detection", + "architecture": "SSD", + "publication_year": 2015, + "num_params": 35641826, + "size": (300, 300), + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16", + "map": 25.1, + }, + ) + DEFAULT = COCO_V1 def _xavier_init(conv: nn.Module): @@ -520,7 +535,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int): +def _vgg_extractor(backbone: VGG, highres: bool, trainable_layers: int): backbone = backbone.features # Gather the indices of maxpools. These are the locations of output blocks. stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1] @@ -537,14 +552,19 @@ def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int): return SSDFeatureExtractorVGG(backbone, highres) +@handle_legacy_interface( + weights=("pretrained", SSD300_VGG16_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES), +) def ssd300_vgg16( - pretrained: bool = False, + *, + weights: Optional[SSD300_VGG16_Weights] = None, progress: bool = True, - num_classes: int = 91, - pretrained_backbone: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[VGG16_Weights] = None, trainable_backbone_layers: Optional[int] = None, **kwargs: Any, -): +) -> SSD: """Constructs an SSD model with input size 300x300 and a VGG16 backbone. Reference: `"SSD: Single Shot MultiBox Detector" `_. @@ -582,31 +602,32 @@ def ssd300_vgg16( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (SSD300_VGG16_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (VGG16_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 4. 
""" + weights = SSD300_VGG16_Weights.verify(weights) + weights_backbone = VGG16_Weights.verify(weights_backbone) + if "size" in kwargs: - warnings.warn("The size of the model is already fixed; ignoring the argument.") + warnings.warn("The size of the model is already fixed; ignoring the parameter.") + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 4 + weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4 ) - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - # Use custom backbones more appropriate for SSD - backbone = vgg.vgg16(pretrained=False, progress=progress) - if pretrained_backbone: - state_dict = load_state_dict_from_url(backbone_urls["vgg16_features"], progress=progress) - backbone.load_state_dict(state_dict) - + backbone = vgg16(weights=weights_backbone, progress=progress) backbone = _vgg_extractor(backbone, False, trainable_backbone_layers) anchor_generator = DefaultBoxGenerator( [[2], [2, 3], [2, 3], [2, 3], [2], [2]], @@ -619,12 +640,10 @@ def ssd300_vgg16( "image_mean": [0.48235, 0.45882, 0.40784], "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of toTensor } - kwargs = {**defaults, **kwargs} + kwargs: Any = {**defaults, **kwargs} model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) - if pretrained: - weights_name = "ssd300_vgg16_coco" - if model_urls.get(weights_name, None) is None: - raise ValueError(f"No checkpoint is available for model {weights_name}") - state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) - model.load_state_dict(state_dict) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 1ee59e069ea..a71da6b29ac 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -6,21 +6,24 @@ import torch from torch import nn, Tensor -from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import Conv2dNormActivation +from ...transforms import ObjectDetectionEval, InterpolationMode from ...utils import _log_api_usage_once from .. import mobilenet +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large from . 
import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers from .ssd import SSD, SSDScoringHead -__all__ = ["ssdlite320_mobilenet_v3_large"] - -model_urls = { - "ssdlite320_mobilenet_v3_large_coco": "https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth" -} +__all__ = [ + "SSDLite320_MobileNet_V3_Large_Weights", + "ssdlite320_mobilenet_v3_large", +] # Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper @@ -178,15 +181,39 @@ def _mobilenet_extractor( return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer) +class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", + transforms=ObjectDetectionEval, + meta={ + "task": "image_object_detection", + "architecture": "SSDLite", + "publication_year": 2018, + "num_params": 3440060, + "size": (320, 320), + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large", + "map": 21.3, + }, + ) + DEFAULT = COCO_V1 + + +@handle_legacy_interface( + weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1), +) def ssdlite320_mobilenet_v3_large( - pretrained: bool = False, + *, + weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None, progress: bool = True, - num_classes: int = 91, - pretrained_backbone: bool = False, + num_classes: Optional[int] = None, + weights_backbone: Optional[MobileNet_V3_Large_Weights] = None, trainable_backbone_layers: Optional[int] = None, norm_layer: Optional[Callable[..., nn.Module]] = None, **kwargs: Any, -): +) -> SSD: """Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone, as described at `"Searching for MobileNetV3" `_ and @@ -203,35 +230,41 @@ def ssdlite320_mobilenet_v3_large( >>> predictions = model(x) Args: - pretrained (bool): If True, returns a model pre-trained on COCO train2017 + weights (SSDLite320_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model progress (bool): If True, displays a progress bar of the download to stderr - num_classes (int): number of output classes of the model (including the background) - pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet - trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + num_classes (int, optional): number of output classes of the model (including the background) + weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone + trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is passed (the default) this value is set to 6. norm_layer (callable, optional): Module specifying the normalization layer to use. 
""" + weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights) + weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone) + if "size" in kwargs: - warnings.warn("The size of the model is already fixed; ignoring the argument.") + warnings.warn("The size of the model is already fixed; ignoring the parameter.") + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6 + weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6 ) - if pretrained: - pretrained_backbone = False - # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper. - reduce_tail = not pretrained_backbone + reduce_tail = weights_backbone is None if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) - backbone = mobilenet.mobilenet_v3_large( - pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs + backbone = mobilenet_v3_large( + weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs ) - if not pretrained_backbone: + if weights_backbone is None: # Change the default initialization scheme if not pretrained _normal_init(backbone) backbone = _mobilenet_extractor( @@ -252,11 +285,11 @@ def ssdlite320_mobilenet_v3_large( "detections_per_img": 300, "topk_candidates": 300, # Rescale the input in a way compatible to the backbone: - # The following mean/std rescale the data from [0, 1] to [-1, 1] + # The following mean/std rescale the data from [0, 1] to [-1, -1] "image_mean": [0.5, 0.5, 0.5], "image_std": [0.5, 0.5, 0.5], } - kwargs = {**defaults, **kwargs} + kwargs: Any = {**defaults, **kwargs} model = SSD( backbone, anchor_generator, @@ -266,10 +299,7 @@ def ssdlite320_mobilenet_v3_large( **kwargs, ) - if pretrained: - weights_name = "ssdlite320_mobilenet_v3_large_coco" - if model_urls.get(weights_name, None) is None: - raise ValueError(f"No checkpoint is available for model {weights_name}") - state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) - model.load_state_dict(state_dict) + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + return model diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 5393827b293..27325c9016c 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -97,7 +97,8 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: if weights is not None: - _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + if weights.meta["categories"] is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index 5988c160aad..3d7baca6284 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -1,4 +1,3 @@ -from . import detection from . import optical_flow from . 
import segmentation
 from . import video
diff --git a/torchvision/prototype/models/detection/__init__.py b/torchvision/prototype/models/detection/__init__.py
deleted file mode 100644
index 4146651c737..00000000000
--- a/torchvision/prototype/models/detection/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .faster_rcnn import *
-from .fcos import *
-from .keypoint_rcnn import *
-from .mask_rcnn import *
-from .retinanet import *
-from .ssd import *
-from .ssdlite import *
diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py
deleted file mode 100644
index 5abc0eef1c4..00000000000
--- a/torchvision/prototype/models/detection/faster_rcnn.py
+++ /dev/null
@@ -1,226 +0,0 @@
-from typing import Any, Optional, Union
-
-from torch import nn
-from torchvision.models._api import WeightsEnum, Weights
-from torchvision.models._meta import _COCO_CATEGORIES
-from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
-from torchvision.models.detection.faster_rcnn import (
-    _mobilenet_extractor,
-    _resnet_fpn_extractor,
-    _validate_trainable_layers,
-    AnchorGenerator,
-    FasterRCNN,
-    misc_nn_ops,
-    overwrite_eps,
-)
-from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
-from torchvision.models.resnet import ResNet50_Weights, resnet50
-from torchvision.transforms import ObjectDetectionEval, InterpolationMode
-
-
-__all__ = [
-    "FasterRCNN",
-    "FasterRCNN_ResNet50_FPN_Weights",
-    "FasterRCNN_MobileNet_V3_Large_FPN_Weights",
-    "FasterRCNN_MobileNet_V3_Large_320_FPN_Weights",
-    "fasterrcnn_resnet50_fpn",
-    "fasterrcnn_mobilenet_v3_large_fpn",
-    "fasterrcnn_mobilenet_v3_large_320_fpn",
-]
-
-
-_COMMON_META = {
-    "task": "image_object_detection",
-    "architecture": "FasterRCNN",
-    "publication_year": 2015,
-    "categories": _COCO_CATEGORIES,
-    "interpolation": InterpolationMode.BILINEAR,
-}
-
-
-class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
-    COCO_V1 = Weights(
-        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
-        transforms=ObjectDetectionEval,
-        meta={
-            **_COMMON_META,
-            "num_params": 41755286,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
-            "map": 37.0,
-        },
-    )
-    DEFAULT = COCO_V1
-
-
-class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
-    COCO_V1 = Weights(
-        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
-        transforms=ObjectDetectionEval,
-        meta={
-            **_COMMON_META,
-            "num_params": 19386354,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn",
-            "map": 32.8,
-        },
-    )
-    DEFAULT = COCO_V1
-
-
-class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
-    COCO_V1 = Weights(
-        url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
-        transforms=ObjectDetectionEval,
-        meta={
-            **_COMMON_META,
-            "num_params": 19386354,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn",
-            "map": 22.8,
-        },
-    )
-    DEFAULT = COCO_V1
-
-
-@handle_legacy_interface(
-    weights=("pretrained", FasterRCNN_ResNet50_FPN_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
-)
-def fasterrcnn_resnet50_fpn(
-    *,
-    weights: Optional[FasterRCNN_ResNet50_FPN_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[ResNet50_Weights] = None,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
-) -> FasterRCNN:
-    weights = FasterRCNN_ResNet50_FPN_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-    elif num_classes is None:
-        num_classes = 91
-
-    is_trained = weights is not None or weights_backbone is not None
-    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
-    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
-
-    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
-    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
-    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-        if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
-            overwrite_eps(model, 0.0)
-
-    return model
-
-
-def _fasterrcnn_mobilenet_v3_large_fpn(
-    *,
-    weights: Optional[Union[FasterRCNN_MobileNet_V3_Large_FPN_Weights, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights]],
-    progress: bool,
-    num_classes: Optional[int],
-    weights_backbone: Optional[MobileNet_V3_Large_Weights],
-    trainable_backbone_layers: Optional[int],
-    **kwargs: Any,
-) -> FasterRCNN:
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-    elif num_classes is None:
-        num_classes = 91
-
-    is_trained = weights is not None or weights_backbone is not None
-    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 6, 3)
-    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
-
-    backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
-    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)
-    anchor_sizes = (
-        (
-            32,
-            64,
-            128,
-            256,
-            512,
-        ),
-    ) * 3
-    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
-    model = FasterRCNN(
-        backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
-    )
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
-
-
-@handle_legacy_interface(
-    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
-)
-def fasterrcnn_mobilenet_v3_large_fpn(
-    *,
-    weights: Optional[FasterRCNN_MobileNet_V3_Large_FPN_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
-) -> FasterRCNN:
-    weights = FasterRCNN_MobileNet_V3_Large_FPN_Weights.verify(weights)
-    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
-
-    defaults = {
-        "rpn_score_thresh": 0.05,
-    }
-
-    kwargs = {**defaults, **kwargs}
-    return _fasterrcnn_mobilenet_v3_large_fpn(
-        weights=weights,
-        progress=progress,
-        num_classes=num_classes,
-        weights_backbone=weights_backbone,
-        trainable_backbone_layers=trainable_backbone_layers,
-        **kwargs,
-    )
-
-
-@handle_legacy_interface(
-    weights=("pretrained", FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
-)
-def fasterrcnn_mobilenet_v3_large_320_fpn(
-    *,
-    weights: Optional[FasterRCNN_MobileNet_V3_Large_320_FPN_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
-) -> FasterRCNN:
-
-    weights = FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.verify(weights)
-    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
-
-    defaults = {
-        "min_size": 320,
-        "max_size": 640,
-        "rpn_pre_nms_top_n_test": 150,
-        "rpn_post_nms_top_n_test": 150,
-        "rpn_score_thresh": 0.05,
-    }
-
-    kwargs = {**defaults, **kwargs}
-    return _fasterrcnn_mobilenet_v3_large_fpn(
-        weights=weights,
-        progress=progress,
-        num_classes=num_classes,
-        weights_backbone=weights_backbone,
-        trainable_backbone_layers=trainable_backbone_layers,
-        **kwargs,
-    )
diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py
deleted file mode 100644
index 930b26e46c8..00000000000
--- a/torchvision/prototype/models/detection/fcos.py
+++ /dev/null
@@ -1,78 +0,0 @@
-from typing import Any, Optional
-
-from torch import nn
-from torchvision.models._api import WeightsEnum, Weights
-from torchvision.models._meta import _COCO_CATEGORIES
-from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
-from torchvision.models.detection.fcos import (
-    _resnet_fpn_extractor,
-    _validate_trainable_layers,
-    FCOS,
-    LastLevelP6P7,
-    misc_nn_ops,
-)
-from torchvision.models.resnet import ResNet50_Weights, resnet50
-from torchvision.transforms import ObjectDetectionEval, InterpolationMode
-
-
-__all__ = [
-    "FCOS",
-    "FCOS_ResNet50_FPN_Weights",
-    "fcos_resnet50_fpn",
-]
-
-
-class FCOS_ResNet50_FPN_Weights(WeightsEnum):
-    COCO_V1 = Weights(
-        url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
-        transforms=ObjectDetectionEval,
-        meta={
-            "task": "image_object_detection",
-            "architecture": "FCOS",
-            "publication_year": 2019,
-            "num_params": 32269600,
-            "categories": _COCO_CATEGORIES,
-            "interpolation": InterpolationMode.BILINEAR,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn",
-            "map": 39.2,
-        },
-    )
-    DEFAULT = COCO_V1
-
-
-@handle_legacy_interface(
-    weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
-)
-def fcos_resnet50_fpn(
-    *,
-    weights: Optional[FCOS_ResNet50_FPN_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[ResNet50_Weights] = None,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
-) -> FCOS:
-    weights = FCOS_ResNet50_FPN_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-    elif num_classes is None:
-        num_classes = 91
-
-    is_trained = weights is not None or weights_backbone is not None
-    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
-    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
-
-    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
-    backbone = _resnet_fpn_extractor(
-        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
-    )
-    model = FCOS(backbone, num_classes, **kwargs)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
diff --git a/torchvision/prototype/models/detection/keypoint_rcnn.py b/torchvision/prototype/models/detection/keypoint_rcnn.py
deleted file mode 100644
index a7780cc9f63..00000000000
--- a/torchvision/prototype/models/detection/keypoint_rcnn.py
+++ /dev/null
@@ -1,106 +0,0 @@
-from typing import Any, Optional
-
-from torch import nn
-from torchvision.models._api import WeightsEnum, Weights
-from torchvision.models._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
-from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
-from torchvision.models.detection.keypoint_rcnn import (
-    _resnet_fpn_extractor,
-    _validate_trainable_layers,
-    KeypointRCNN,
-    misc_nn_ops,
-    overwrite_eps,
-)
-from torchvision.models.resnet import ResNet50_Weights, resnet50
-from torchvision.transforms import ObjectDetectionEval, InterpolationMode
-
-
-__all__ = [
-    "KeypointRCNN",
-    "KeypointRCNN_ResNet50_FPN_Weights",
-    "keypointrcnn_resnet50_fpn",
-]
-
-
-_COMMON_META = {
-    "task": "image_object_detection",
-    "architecture": "KeypointRCNN",
-    "publication_year": 2017,
-    "categories": _COCO_PERSON_CATEGORIES,
-    "keypoint_names": _COCO_PERSON_KEYPOINT_NAMES,
-    "interpolation": InterpolationMode.BILINEAR,
-}
-
-
-class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
-    COCO_LEGACY = Weights(
-        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
-        transforms=ObjectDetectionEval,
-        meta={
-            **_COMMON_META,
-            "num_params": 59137258,
-            "recipe": "https://github.com/pytorch/vision/issues/1606",
-            "map": 50.6,
-            "map_kp": 61.1,
-        },
-    )
-    COCO_V1 = Weights(
-        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
-        transforms=ObjectDetectionEval,
-        meta={
-            **_COMMON_META,
-            "num_params": 59137258,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
-            "map": 54.6,
-            "map_kp": 65.0,
-        },
-    )
-    DEFAULT = COCO_V1
-
-
-@handle_legacy_interface(
-    weights=(
-        "pretrained",
-        lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
-        if kwargs["pretrained"] == "legacy"
-        else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
-    ),
-    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
-)
-def keypointrcnn_resnet50_fpn(
-    *,
-    weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    num_keypoints: Optional[int] = None,
-    weights_backbone: Optional[ResNet50_Weights] = None,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
-) -> KeypointRCNN:
-    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-        num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"]))
-    else:
-        if num_classes is None:
-            num_classes = 2
-        if num_keypoints is None:
-            num_keypoints = 17
-
-    is_trained = weights is not None or weights_backbone is not None
-    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
-    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
-
-    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
-    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
-    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
-            overwrite_eps(model, 0.0)
-
-    return model
diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py
deleted file mode 100644
index d52ebe61be1..00000000000
--- a/torchvision/prototype/models/detection/mask_rcnn.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from typing import Any, Optional
-
-from torch import nn
-from torchvision.models._api import WeightsEnum, Weights
-from torchvision.models._meta import _COCO_CATEGORIES
-from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
-from torchvision.models.detection.mask_rcnn import (
-    _resnet_fpn_extractor,
-    _validate_trainable_layers,
-    MaskRCNN,
-    misc_nn_ops,
-    overwrite_eps,
-)
-from torchvision.models.resnet import ResNet50_Weights, resnet50
-from torchvision.transforms import ObjectDetectionEval, InterpolationMode
-
-
-__all__ = [
-    "MaskRCNN",
-    "MaskRCNN_ResNet50_FPN_Weights",
-    "maskrcnn_resnet50_fpn",
-]
-
-
-class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
-    COCO_V1 = Weights(
-        url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
-        transforms=ObjectDetectionEval,
-        meta={
-            "task": "image_object_detection",
-            "architecture": "MaskRCNN",
-            "publication_year": 2017,
-            "num_params": 44401393,
-            "categories": _COCO_CATEGORIES,
-            "interpolation": InterpolationMode.BILINEAR,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
-            "map": 37.9,
-            "map_mask": 34.6,
-        },
-    )
-    DEFAULT = COCO_V1
-
-
-@handle_legacy_interface(
-    weights=("pretrained", MaskRCNN_ResNet50_FPN_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
-)
-def maskrcnn_resnet50_fpn(
-    *,
-    weights: Optional[MaskRCNN_ResNet50_FPN_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[ResNet50_Weights] = None,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
-) -> MaskRCNN:
-    weights = MaskRCNN_ResNet50_FPN_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-    elif num_classes is None:
-        num_classes = 91
-
-    is_trained = weights is not None or weights_backbone is not None
-    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
-    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
-
-    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
-    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
-    model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-        if weights == MaskRCNN_ResNet50_FPN_Weights.COCO_V1:
-            overwrite_eps(model, 0.0)
-
-    return model
diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py
deleted file mode 100644
index c4249118b70..00000000000
--- a/torchvision/prototype/models/detection/retinanet.py
+++ /dev/null
@@ -1,82 +0,0 @@
-from typing import Any, Optional
-
-from torch import nn
-from torchvision.models._api import WeightsEnum, Weights
-from torchvision.models._meta import _COCO_CATEGORIES
-from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
-from torchvision.models.detection.retinanet import (
-    _resnet_fpn_extractor,
-    _validate_trainable_layers,
-    RetinaNet,
-    LastLevelP6P7,
-    misc_nn_ops,
-    overwrite_eps,
-)
-from torchvision.models.resnet import ResNet50_Weights, resnet50
-from torchvision.transforms import ObjectDetectionEval, InterpolationMode
-
-
-__all__ = [
-    "RetinaNet",
-    "RetinaNet_ResNet50_FPN_Weights",
-    "retinanet_resnet50_fpn",
-]
-
-
-class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
-    COCO_V1 = Weights(
-        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
-        transforms=ObjectDetectionEval,
-        meta={
-            "task": "image_object_detection",
-            "architecture": "RetinaNet",
-            "publication_year": 2017,
-            "num_params": 34014999,
-            "categories": _COCO_CATEGORIES,
-            "interpolation": InterpolationMode.BILINEAR,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
-            "map": 36.4,
-        },
-    )
-    DEFAULT = COCO_V1
-
-
-@handle_legacy_interface(
-    weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
-)
-def retinanet_resnet50_fpn(
-    *,
-    weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[ResNet50_Weights] = None,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
-) -> RetinaNet:
-    weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
-    weights_backbone = ResNet50_Weights.verify(weights_backbone)
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-    elif num_classes is None:
-        num_classes = 91
-
-    is_trained = weights is not None or weights_backbone is not None
-    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
-    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d
-
-    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
-    # skip P2 because it generates too many anchors (according to their paper)
-    backbone = _resnet_fpn_extractor(
-        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
-    )
-    model = RetinaNet(backbone, num_classes, **kwargs)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-        if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
-            overwrite_eps(model, 0.0)
-
-    return model
diff --git a/torchvision/prototype/models/detection/ssd.py b/torchvision/prototype/models/detection/ssd.py
deleted file mode 100644
index a3c5b965deb..00000000000
--- a/torchvision/prototype/models/detection/ssd.py
+++ /dev/null
@@ -1,91 +0,0 @@
-import warnings
-from typing import Any, Optional
-
-from torchvision.models._api import WeightsEnum, Weights
-from torchvision.models._meta import _COCO_CATEGORIES
-from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
-from torchvision.models.detection.ssd import (
-    _validate_trainable_layers,
-    _vgg_extractor,
-    DefaultBoxGenerator,
-    SSD,
-)
-from torchvision.models.vgg import VGG16_Weights, vgg16
-from torchvision.transforms import ObjectDetectionEval, InterpolationMode
-
-
-__all__ = [
-    "SSD300_VGG16_Weights",
-    "ssd300_vgg16",
-]
-
-
-class SSD300_VGG16_Weights(WeightsEnum):
-    COCO_V1 = Weights(
-        url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
-        transforms=ObjectDetectionEval,
-        meta={
-            "task": "image_object_detection",
-            "architecture": "SSD",
-            "publication_year": 2015,
-            "num_params": 35641826,
-            "size": (300, 300),
-            "categories": _COCO_CATEGORIES,
-            "interpolation": InterpolationMode.BILINEAR,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssd300-vgg16",
-            "map": 25.1,
-        },
-    )
-    DEFAULT = COCO_V1
-
-
-@handle_legacy_interface(
-    weights=("pretrained", SSD300_VGG16_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", VGG16_Weights.IMAGENET1K_FEATURES),
-)
-def ssd300_vgg16(
-    *,
-    weights: Optional[SSD300_VGG16_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[VGG16_Weights] = None,
-    trainable_backbone_layers: Optional[int] = None,
-    **kwargs: Any,
-) -> SSD:
-    weights = SSD300_VGG16_Weights.verify(weights)
-    weights_backbone = VGG16_Weights.verify(weights_backbone)
-
-    if "size" in kwargs:
-        warnings.warn("The size of the model is already fixed; ignoring the parameter.")
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-    elif num_classes is None:
-        num_classes = 91
-
-    trainable_backbone_layers = _validate_trainable_layers(
-        weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 4
-    )
-
-    # Use custom backbones more appropriate for SSD
-    backbone = vgg16(weights=weights_backbone, progress=progress)
-    backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
-    anchor_generator = DefaultBoxGenerator(
-        [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
-        scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
-        steps=[8, 16, 32, 64, 100, 300],
-    )
-
-    defaults = {
-        # Rescale the input in a way compatible to the backbone
-        "image_mean": [0.48235, 0.45882, 0.40784],
-        "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0],  # undo the 0-1 scaling of toTensor
-    }
-    kwargs: Any = {**defaults, **kwargs}
-    model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs)
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py
deleted file mode 100644
index d9f2ee58bc6..00000000000
--- a/torchvision/prototype/models/detection/ssdlite.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import warnings
-from functools import partial
-from typing import Any, Callable, Optional
-
-from torch import nn
-from torchvision.models._api import WeightsEnum, Weights
-from torchvision.models._meta import _COCO_CATEGORIES
-from torchvision.models._utils import handle_legacy_interface, _ovewrite_value_param
-from torchvision.models.detection.ssdlite import (
-    _mobilenet_extractor,
-    _normal_init,
-    _validate_trainable_layers,
-    DefaultBoxGenerator,
-    det_utils,
-    SSD,
-    SSDLiteHead,
-)
-from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
-from torchvision.transforms import ObjectDetectionEval, InterpolationMode
-
-
-__all__ = [
-    "SSDLite320_MobileNet_V3_Large_Weights",
-    "ssdlite320_mobilenet_v3_large",
-]
-
-
-class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
-    COCO_V1 = Weights(
-        url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
-        transforms=ObjectDetectionEval,
-        meta={
-            "task": "image_object_detection",
-            "architecture": "SSDLite",
-            "publication_year": 2018,
-            "num_params": 3440060,
-            "size": (320, 320),
-            "categories": _COCO_CATEGORIES,
-            "interpolation": InterpolationMode.BILINEAR,
-            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
-            "map": 21.3,
-        },
-    )
-    DEFAULT = COCO_V1
-
-
-@handle_legacy_interface(
-    weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1),
-    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
-)
-def ssdlite320_mobilenet_v3_large(
-    *,
-    weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None,
-    progress: bool = True,
-    num_classes: Optional[int] = None,
-    weights_backbone: Optional[MobileNet_V3_Large_Weights] = None,
-    trainable_backbone_layers: Optional[int] = None,
-    norm_layer: Optional[Callable[..., nn.Module]] = None,
-    **kwargs: Any,
-) -> SSD:
-    weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights)
-    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)
-
-    if "size" in kwargs:
-        warnings.warn("The size of the model is already fixed; ignoring the parameter.")
-
-    if weights is not None:
-        weights_backbone = None
-        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
-    elif num_classes is None:
-        num_classes = 91
-
-    trainable_backbone_layers = _validate_trainable_layers(
-        weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6
-    )
-
-    # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper.
-    reduce_tail = weights_backbone is None
-
-    if norm_layer is None:
-        norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
-
-    backbone = mobilenet_v3_large(
-        weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
-    )
-    if weights_backbone is None:
-        # Change the default initialization scheme if not pretrained
-        _normal_init(backbone)
-    backbone = _mobilenet_extractor(
-        backbone,
-        trainable_backbone_layers,
-        norm_layer,
-    )
-
-    size = (320, 320)
-    anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
-    out_channels = det_utils.retrieve_out_channels(backbone, size)
-    num_anchors = anchor_generator.num_anchors_per_location()
-    assert len(out_channels) == len(anchor_generator.aspect_ratios)
-
-    defaults = {
-        "score_thresh": 0.001,
-        "nms_thresh": 0.55,
-        "detections_per_img": 300,
-        "topk_candidates": 300,
-        # Rescale the input in a way compatible to the backbone:
-        # The following mean/std rescale the data from [0, 1] to [-1, 1]
-        "image_mean": [0.5, 0.5, 0.5],
-        "image_std": [0.5, 0.5, 0.5],
-    }
-    kwargs: Any = {**defaults, **kwargs}
-    model = SSD(
-        backbone,
-        anchor_generator,
-        size,
-        num_classes,
-        head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer),
-        **kwargs,
-    )
-
-    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
-
-    return model
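For orientation, here is a minimal usage sketch of the consolidated multi-weight API that the deleted prototype builders above now expose from torchvision.models.detection. It is illustrative rather than part of the diff: it assumes the signatures shown in these hunks (the keyword-only weights parameter, the weights.transforms() preset, and the weights.meta dict), which later torchvision releases may rename or restructure.

# Hedged usage sketch (not part of this diff): exercises the consolidated
# multi-weight API that the removed prototype modules used to provide.
import torch
from torchvision.models.detection import (
    FasterRCNN_ResNet50_FPN_Weights,
    fasterrcnn_resnet50_fpn,
)

weights = FasterRCNN_ResNet50_FPN_Weights.COCO_V1   # or .DEFAULT
model = fasterrcnn_resnet50_fpn(weights=weights)    # replaces pretrained=True
model.eval()

# The Weights entry bundles its own eval transforms and COCO metadata.
preprocess = weights.transforms()
batch = [preprocess(torch.rand(3, 480, 640))]  # detection models take a list of images

with torch.inference_mode():
    detections = model(batch)

# Each prediction dict carries boxes, labels, and scores; meta holds e.g. the COCO mAP.
print(detections[0]["boxes"].shape, weights.meta["map"])

Bundling url, transforms, and meta in one enum entry is what lets every builder in the deleted files stay thin: verify the enum, derive num_classes from meta, and load the state dict, with no per-model URL table.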