Commit ce2b49c

mthrok authored and facebook-github-bot committed
Add High-res FasterRCNN MobileNetV3 and tune Low-res for speed (#3265)
Summary:
* Tag fasterrcnn mobilenetv3 model with 320, add new inference config that makes it 2x faster sacrificing a bit of mAP.
* Add a high resolution fasterrcnn mobilenetv3 model.
* Update tests and expected values.

Reviewed By: datumbox

Differential Revision: D25954564

fbshipit-source-id: f6b64981d2bc83e3577435481a569df38b427b20
1 parent 8fbf214 commit ce2b49c

7 files changed: +109 -45 lines changed

docs/source/models.rst

Lines changed: 20 additions & 17 deletions
@@ -358,14 +358,15 @@ models return the predictions of the following classes:
 Here are the summary of the accuracies for the models trained on
 the instances set of COCO train2017 and evaluated on COCO val2017.

-================================== ======= ======== ===========
-Network                            box AP  mask AP  keypoint AP
-================================== ======= ======== ===========
-Faster R-CNN ResNet-50 FPN         37.0    -        -
-Faster R-CNN MobileNetV3-Large FPN 23.0    -        -
-RetinaNet ResNet-50 FPN            36.4    -        -
-Mask R-CNN ResNet-50 FPN           37.9    34.6     -
-================================== ======= ======== ===========
+====================================== ======= ======== ===========
+Network                                box AP  mask AP  keypoint AP
+====================================== ======= ======== ===========
+Faster R-CNN ResNet-50 FPN             37.0    -        -
+Faster R-CNN MobileNetV3-Large FPN     32.8    -        -
+Faster R-CNN MobileNetV3-Large 320 FPN 22.8    -        -
+RetinaNet ResNet-50 FPN                36.4    -        -
+Mask R-CNN ResNet-50 FPN               37.9    34.6     -
+====================================== ======= ======== ===========

 For person keypoint detection, the accuracies for the pre-trained
 models are as follows
@@ -415,22 +416,24 @@ For test time, we report the time for the model evaluation and postprocessing
 (including mask pasting in image), but not the time for computing the
 precision-recall.

-================================== =================== ================== ===========
-Network                            train time (s / it) test time (s / it) memory (GB)
-================================== =================== ================== ===========
-Faster R-CNN ResNet-50 FPN         0.2288              0.0590             5.2
-Faster R-CNN MobileNetV3-Large FPN 0.0978              0.0376             0.6
-RetinaNet ResNet-50 FPN            0.2514              0.0939             4.1
-Mask R-CNN ResNet-50 FPN           0.2728              0.0903             5.4
-Keypoint R-CNN ResNet-50 FPN       0.3789              0.1242             6.8
-================================== =================== ================== ===========
+====================================== =================== ================== ===========
+Network                                train time (s / it) test time (s / it) memory (GB)
+====================================== =================== ================== ===========
+Faster R-CNN ResNet-50 FPN             0.2288              0.0590             5.2
+Faster R-CNN MobileNetV3-Large FPN     0.1020              0.0415             1.0
+Faster R-CNN MobileNetV3-Large 320 FPN 0.0978              0.0376             0.6
+RetinaNet ResNet-50 FPN                0.2514              0.0939             4.1
+Mask R-CNN ResNet-50 FPN               0.2728              0.0903             5.4
+Keypoint R-CNN ResNet-50 FPN           0.3789              0.1242             6.8
+====================================== =================== ================== ===========


 Faster R-CNN
 ------------

 .. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn
 .. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn
+.. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn


 RetinaNet
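
The rows added to the two tables above allow a rough comparison of the two MobileNetV3-Large variants. The following is only illustrative arithmetic on the values listed in the diff; the `MODELS` dictionary and the `compare` helper are hypothetical names, not part of torchvision.

```python
# Values copied from the rows added to docs/source/models.rst in this commit.
# The dictionary layout and helper name are illustrative assumptions.
MODELS = {
    "Faster R-CNN MobileNetV3-Large FPN": {
        "box_ap": 32.8, "train_s_it": 0.1020, "test_s_it": 0.0415, "memory_gb": 1.0,
    },
    "Faster R-CNN MobileNetV3-Large 320 FPN": {
        "box_ap": 22.8, "train_s_it": 0.0978, "test_s_it": 0.0376, "memory_gb": 0.6,
    },
}


def compare(high, low):
    """Summarize how the high-resolution row differs from the 320 row."""
    return {
        # Absolute difference in box AP points.
        "box_ap_drop": round(high["box_ap"] - low["box_ap"], 1),
        # Ratio of listed test times (values above 1 mean the 320 row is faster).
        "test_time_ratio": round(high["test_s_it"] / low["test_s_it"], 2),
        # Ratio of listed memory use.
        "memory_ratio": round(high["memory_gb"] / low["memory_gb"], 2),
    }


result = compare(
    MODELS["Faster R-CNN MobileNetV3-Large FPN"],
    MODELS["Faster R-CNN MobileNetV3-Large 320 FPN"],
)
print(result)
```

These ratios use the table values exactly as listed; they say nothing about how the timings were measured.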

references/detection/README.md

Lines changed: 7 additions & 0 deletions
@@ -34,6 +34,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --lr-steps 16 22 --aspect-ratio-group-factor 3
 ```

+### Faster R-CNN MobileNetV3-Large 320 FPN
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
+    --dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\
+    --lr-steps 16 22 --aspect-ratio-group-factor 3
+```
+
 ### RetinaNet
 ```
 python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
Binary file not shown.
Binary file not shown.

test/test_models.py

Lines changed: 5 additions & 1 deletion
@@ -38,6 +38,7 @@ def get_available_video_models():
     'inception_v3': lambda x: x.logits,
     "fasterrcnn_resnet50_fpn": lambda x: x[1],
     "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
+    "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
     "maskrcnn_resnet50_fpn": lambda x: x[1],
     "keypointrcnn_resnet50_fpn": lambda x: x[1],
     "retinanet_resnet50_fpn": lambda x: x[1],
@@ -106,8 +107,11 @@ def _test_detection_model(self, name, dev):
         if "retinanet" in name:
             # Reduce the default threshold to ensure the returned boxes are not empty.
             kwargs["score_thresh"] = 0.01
-        elif "fasterrcnn_mobilenet" in name:
+        elif "fasterrcnn_mobilenet_v3_large" in name:
             kwargs["box_score_thresh"] = 0.02076
+            if "fasterrcnn_mobilenet_v3_large_320_fpn" in name:
+                kwargs["rpn_pre_nms_top_n_test"] = 1000
+                kwargs["rpn_post_nms_top_n_test"] = 1000
         model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
         model.eval().to(device=dev)
         input_shape = (3, 300, 300)
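
The branch logic in `_test_detection_model` relies on substring matching: both MobileNetV3-Large variants match the `elif`, and only the 320 variant matches the nested `if`. A minimal standalone sketch of that selection, assuming a hypothetical helper name `detection_test_kwargs` (the real test builds `kwargs` inline):

```python
def detection_test_kwargs(name):
    """Return the extra constructor kwargs the test would pass for `name`.

    Hypothetical helper that mirrors the branches shown in the diff.
    """
    kwargs = {}
    if "retinanet" in name:
        # Lower threshold so the returned boxes are not empty.
        kwargs["score_thresh"] = 0.01
    elif "fasterrcnn_mobilenet_v3_large" in name:
        # Matches both the high-resolution and the 320 variant.
        kwargs["box_score_thresh"] = 0.02076
        if "fasterrcnn_mobilenet_v3_large_320_fpn" in name:
            # Overrides the 320 model's default proposal limits of 150.
            kwargs["rpn_pre_nms_top_n_test"] = 1000
            kwargs["rpn_post_nms_top_n_test"] = 1000
    return kwargs


for model_name in (
    "fasterrcnn_resnet50_fpn",
    "fasterrcnn_mobilenet_v3_large_fpn",
    "fasterrcnn_mobilenet_v3_large_320_fpn",
):
    print(model_name, detection_test_kwargs(model_name))
```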

test/test_models_detection_negative_samples.py

Lines changed: 2 additions & 1 deletion
@@ -97,7 +97,8 @@ def test_assign_targets_to_proposals(self):
         self.assertEqual(labels[0].dtype, torch.int64)

     def test_forward_negative_sample_frcnn(self):
-        for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"]:
+        for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn",
+                     "fasterrcnn_mobilenet_v3_large_320_fpn"]:
             model = torchvision.models.detection.__dict__[name](
                 num_classes=2, min_size=100, max_size=100)


torchvision/models/detection/faster_rcnn.py

Lines changed: 75 additions & 26 deletions
@@ -16,7 +16,8 @@


 __all__ = [
-    "FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"
+    "FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_320_fpn",
+    "fasterrcnn_mobilenet_v3_large_fpn"
 ]


@@ -288,8 +289,10 @@ def forward(self, x):
 model_urls = {
     'fasterrcnn_resnet50_fpn_coco':
         'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
+    'fasterrcnn_mobilenet_v3_large_320_fpn_coco':
+        'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth',
     'fasterrcnn_mobilenet_v3_large_fpn_coco':
-        'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-907ea3f9.pth',
+        'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth'
 }


@@ -368,16 +371,38 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
     return model


-def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
-                                      trainable_backbone_layers=None, min_size=320, max_size=640, rpn_score_thresh=0.05,
-                                      **kwargs):
+def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress=True, num_classes=91,
+                                       pretrained_backbone=True, trainable_backbone_layers=None, **kwargs):
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
+
+    if pretrained:
+        pretrained_backbone = False
+    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
+                                  trainable_layers=trainable_backbone_layers)
+
+    anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
+    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+
+    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
+                       **kwargs)
+    if pretrained:
+        if model_urls.get(weights_name, None) is None:
+            raise ValueError("No checkpoint is available for model {}".format(weights_name))
+        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
+                                          trainable_backbone_layers=None, **kwargs):
     """
-    Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
-    to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
+    Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tuned for mobile use-cases.
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.

     Example::

-        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True)
         >>> model.eval()
         >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
         >>> predictions = model(x)
@@ -389,25 +414,49 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class
         pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
         trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
             Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
-        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
-        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
-        rpn_score_thresh (float): during inference, only return proposals with a classification score
-            greater than rpn_score_thresh
     """
-    trainable_backbone_layers = _validate_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
+    weights_name = "fasterrcnn_mobilenet_v3_large_320_fpn_coco"
+    defaults = {
+        "min_size": 320,
+        "max_size": 640,
+        "rpn_pre_nms_top_n_test": 150,
+        "rpn_post_nms_top_n_test": 150,
+        "rpn_score_thresh": 0.05,
+    }

-    if pretrained:
-        pretrained_backbone = False
-    backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
-                                  trainable_layers=trainable_backbone_layers)
+    kwargs = {**defaults, **kwargs}
+    return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress,
+                                              num_classes=num_classes, pretrained_backbone=pretrained_backbone,
+                                              trainable_backbone_layers=trainable_backbone_layers, **kwargs)

-    anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
-    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

-    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
-                       min_size=min_size, max_size=max_size, rpn_score_thresh=rpn_score_thresh, **kwargs)
-    if pretrained:
-        state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress)
-        model.load_state_dict(state_dict)
-    return model
+def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
+                                      trainable_backbone_layers=None, **kwargs):
+    """
+    Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
+    It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on COCO train2017
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
+        trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
+            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
+    """
+    weights_name = "fasterrcnn_mobilenet_v3_large_fpn_coco"
+    defaults = {
+        "rpn_score_thresh": 0.05,
+    }
+
+    kwargs = {**defaults, **kwargs}
+    return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress,
+                                              num_classes=num_classes, pretrained_backbone=pretrained_backbone,
+                                              trainable_backbone_layers=trainable_backbone_layers, **kwargs)
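
Both public constructors in this diff merge a per-variant `defaults` dictionary with the caller's `**kwargs` before delegating to the shared private builder, so any caller-supplied value wins over the default. A minimal sketch of that pattern in plain Python, assuming a hypothetical `build_model` stand-in that just records what it receives (it is not the `FasterRCNN` constructor):

```python
def build_model(**kwargs):
    """Hypothetical stand-in for the shared builder: returns the kwargs it was given."""
    return dict(kwargs)


def low_res_variant(**kwargs):
    # Defaults copied from fasterrcnn_mobilenet_v3_large_320_fpn in the diff.
    defaults = {
        "min_size": 320,
        "max_size": 640,
        "rpn_pre_nms_top_n_test": 150,
        "rpn_post_nms_top_n_test": 150,
        "rpn_score_thresh": 0.05,
    }
    # Later entries win in a dict merge, so caller kwargs override the defaults.
    kwargs = {**defaults, **kwargs}
    return build_model(**kwargs)


default_cfg = low_res_variant()
custom_cfg = low_res_variant(min_size=512, box_score_thresh=0.5)
print(default_cfg)
print(custom_cfg)
```

With this pattern, parameters such as `min_size` no longer need to appear in each public signature, while callers can still override them by keyword.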
