diff --git a/docs/source/models.rst b/docs/source/models.rst index fa2dfec14d9..7fbae2a55d1 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -358,14 +358,15 @@ models return the predictions of the following classes: Here are the summary of the accuracies for the models trained on the instances set of COCO train2017 and evaluated on COCO val2017. -================================== ======= ======== =========== -Network box AP mask AP keypoint AP -================================== ======= ======== =========== -Faster R-CNN ResNet-50 FPN 37.0 - - -Faster R-CNN MobileNetV3-Large FPN 23.0 - - -RetinaNet ResNet-50 FPN 36.4 - - -Mask R-CNN ResNet-50 FPN 37.9 34.6 - -================================== ======= ======== =========== +====================================== ======= ======== =========== +Network box AP mask AP keypoint AP +====================================== ======= ======== =========== +Faster R-CNN ResNet-50 FPN 37.0 - - +Faster R-CNN MobileNetV3-Large FPN 32.8 - - +Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - - +RetinaNet ResNet-50 FPN 36.4 - - +Mask R-CNN ResNet-50 FPN 37.9 34.6 - +====================================== ======= ======== =========== For person keypoint detection, the accuracies for the pre-trained models are as follows @@ -415,15 +416,16 @@ For test time, we report the time for the model evaluation and postprocessing (including mask pasting in image), but not the time for computing the precision-recall. 
-================================== =================== ================== =========== -Network train time (s / it) test time (s / it) memory (GB) -================================== =================== ================== =========== -Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2 -Faster R-CNN MobileNetV3-Large FPN 0.0978 0.0376 0.6 -RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1 -Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4 -Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8 -================================== =================== ================== =========== +====================================== =================== ================== =========== +Network train time (s / it) test time (s / it) memory (GB) +====================================== =================== ================== =========== +Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2 +Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0 +Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6 +RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1 +Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4 +Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8 +====================================== =================== ================== =========== Faster R-CNN @@ -431,6 +433,7 @@ Faster R-CNN .. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn .. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn +.. 
autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn RetinaNet diff --git a/references/detection/README.md b/references/detection/README.md index e7ac6e48e11..c8eaf46da5f 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -34,6 +34,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ --lr-steps 16 22 --aspect-ratio-group-factor 3 ``` +### Faster R-CNN MobileNetV3-Large 320 FPN +``` +python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ + --dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\ + --lr-steps 16 22 --aspect-ratio-group-factor 3 +``` + ### RetinaNet ``` python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ diff --git a/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_320_fpn_expect.pkl b/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_320_fpn_expect.pkl new file mode 100644 index 00000000000..94c6261b7fc Binary files /dev/null and b/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_320_fpn_expect.pkl differ diff --git a/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_fpn_expect.pkl b/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_fpn_expect.pkl index 94c6261b7fc..f3882de4838 100644 Binary files a/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_fpn_expect.pkl and b/test/expect/ModelTester.test_fasterrcnn_mobilenet_v3_large_fpn_expect.pkl differ diff --git a/test/test_models.py b/test/test_models.py index 232a78234b9..28236598177 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -38,6 +38,7 @@ def get_available_video_models(): 'inception_v3': lambda x: x.logits, "fasterrcnn_resnet50_fpn": lambda x: x[1], "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1], + "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1], "maskrcnn_resnet50_fpn": lambda x: x[1], "keypointrcnn_resnet50_fpn": lambda x: x[1], "retinanet_resnet50_fpn": lambda x: 
x[1], @@ -106,8 +107,11 @@ def _test_detection_model(self, name, dev): if "retinanet" in name: # Reduce the default threshold to ensure the returned boxes are not empty. kwargs["score_thresh"] = 0.01 - elif "fasterrcnn_mobilenet" in name: + elif "fasterrcnn_mobilenet_v3_large" in name: kwargs["box_score_thresh"] = 0.02076 + if "fasterrcnn_mobilenet_v3_large_320_fpn" in name: + kwargs["rpn_pre_nms_top_n_test"] = 1000 + kwargs["rpn_post_nms_top_n_test"] = 1000 model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs) model.eval().to(device=dev) input_shape = (3, 300, 300) diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index cb35f35894b..ad976a78b09 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -97,7 +97,8 @@ def test_assign_targets_to_proposals(self): self.assertEqual(labels[0].dtype, torch.int64) def test_forward_negative_sample_frcnn(self): - for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"]: + for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn"]: model = torchvision.models.detection.__dict__[name]( num_classes=2, min_size=100, max_size=100) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index f5b4696e2ce..c37a5632ebd 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -16,7 +16,8 @@ __all__ = [ - "FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn" + "FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_320_fpn", + "fasterrcnn_mobilenet_v3_large_fpn" ] @@ -288,8 +289,10 @@ def forward(self, x): model_urls = { 'fasterrcnn_resnet50_fpn_coco': 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth', + 
'fasterrcnn_mobilenet_v3_large_320_fpn_coco': + 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth', 'fasterrcnn_mobilenet_v3_large_fpn_coco': - 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-907ea3f9.pth', + 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth' } @@ -368,16 +371,38 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, return model -def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, - trainable_backbone_layers=None, min_size=320, max_size=640, rpn_score_thresh=0.05, - **kwargs): +def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress=True, num_classes=91, + pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3) + + if pretrained: + pretrained_backbone = False + backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True, + trainable_layers=trainable_backbone_layers) + + anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3 + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + + model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), + **kwargs) + if pretrained: + if model_urls.get(weights_name, None) is None: + raise ValueError("No checkpoint is available for model {}".format(weights_name)) + state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) + model.load_state_dict(state_dict) + return model + + +def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, + trainable_backbone_layers=None, **kwargs): """ - Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly - to Faster R-CNN with ResNet-50 FPN backbone. 
See `fasterrcnn_resnet50_fpn` for more details. + Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tuned for mobile use-cases. + It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details. Example:: - >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True) + >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True) >>> model.eval() >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) @@ -389,25 +414,49 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. - min_size (int): minimum size of the image to be rescaled before feeding it to the backbone - max_size (int): maximum size of the image to be rescaled before feeding it to the backbone - rpn_score_thresh (float): during inference, only return proposals with a classification score - greater than rpn_score_thresh """ - trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3) + weights_name = "fasterrcnn_mobilenet_v3_large_320_fpn_coco" + defaults = { + "min_size": 320, + "max_size": 640, + "rpn_pre_nms_top_n_test": 150, + "rpn_post_nms_top_n_test": 150, + "rpn_score_thresh": 0.05, + } - if pretrained: - pretrained_backbone = False - backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True, - trainable_layers=trainable_backbone_layers) + kwargs = {**defaults, **kwargs} + return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress, + num_classes=num_classes, 
pretrained_backbone=pretrained_backbone, + trainable_backbone_layers=trainable_backbone_layers, **kwargs) - anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3 - aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), - min_size=min_size, max_size=max_size, rpn_score_thresh=rpn_score_thresh, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress) - model.load_state_dict(state_dict) - return model +def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, + trainable_backbone_layers=None, **kwargs): + """ + Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone. + It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details. + + Example:: + + >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int): number of output classes of the model (including the background) + pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet + trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block. + Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. 
+ """ + weights_name = "fasterrcnn_mobilenet_v3_large_fpn_coco" + defaults = { + "rpn_score_thresh": 0.05, + } + + kwargs = {**defaults, **kwargs} + return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress, + num_classes=num_classes, pretrained_backbone=pretrained_backbone, + trainable_backbone_layers=trainable_backbone_layers, **kwargs)