Skip to content

Commit 12d616c

Browse files
Vincent Moens and facebook-github-bot
Vincent Moens
authored and committed
[fbsync] Multi-pretrained weight support - FasterRCNN ResNet50 (#4613)
Summary: * Adding FasterRCNN ResNet50. * Refactoring to remove duplicate code. * Adding typing info. * Setting weights_backbone=None as default value. * Overwrite eps only for specific weights. Reviewed By: NicolasHug Differential Revision: D31758312 fbshipit-source-id: 714a714d897bb4b4d9da1298ad5e2606998898b9
1 parent ef972a3 commit 12d616c

File tree

7 files changed

+200
-23
lines changed

7 files changed

+200
-23
lines changed

torchvision/models/detection/backbone_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import warnings
2+
from typing import List, Optional
23

34
from torch import nn
45
from torchvision.ops import misc as misc_nn_ops
5-
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
6+
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool, ExtraFPNBlock
67

78
from .. import mobilenet
89
from .. import resnet
@@ -92,7 +93,15 @@ def resnet_fpn_backbone(
9293
default a ``LastLevelMaxPool`` is used.
9394
"""
9495
backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
96+
return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
9597

98+
99+
def _resnet_backbone_config(
100+
backbone: resnet.ResNet,
101+
trainable_layers: int,
102+
returned_layers: Optional[List[int]],
103+
extra_blocks: Optional[ExtraFPNBlock],
104+
):
96105
# select layers that wont be frozen
97106
assert 0 <= trainable_layers <= 5
98107
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .resnet import *
2+
from . import detection

torchvision/prototype/models/_meta.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,3 +1006,98 @@
10061006
"ear",
10071007
"toilet tissue",
10081008
]
1009+
1010+
# To be replaced with torchvision.datasets.find("coco").info.categories
1011+
_COCO_CATEGORIES = [
1012+
"__background__",
1013+
"person",
1014+
"bicycle",
1015+
"car",
1016+
"motorcycle",
1017+
"airplane",
1018+
"bus",
1019+
"train",
1020+
"truck",
1021+
"boat",
1022+
"traffic light",
1023+
"fire hydrant",
1024+
"N/A",
1025+
"stop sign",
1026+
"parking meter",
1027+
"bench",
1028+
"bird",
1029+
"cat",
1030+
"dog",
1031+
"horse",
1032+
"sheep",
1033+
"cow",
1034+
"elephant",
1035+
"bear",
1036+
"zebra",
1037+
"giraffe",
1038+
"N/A",
1039+
"backpack",
1040+
"umbrella",
1041+
"N/A",
1042+
"N/A",
1043+
"handbag",
1044+
"tie",
1045+
"suitcase",
1046+
"frisbee",
1047+
"skis",
1048+
"snowboard",
1049+
"sports ball",
1050+
"kite",
1051+
"baseball bat",
1052+
"baseball glove",
1053+
"skateboard",
1054+
"surfboard",
1055+
"tennis racket",
1056+
"bottle",
1057+
"N/A",
1058+
"wine glass",
1059+
"cup",
1060+
"fork",
1061+
"knife",
1062+
"spoon",
1063+
"bowl",
1064+
"banana",
1065+
"apple",
1066+
"sandwich",
1067+
"orange",
1068+
"broccoli",
1069+
"carrot",
1070+
"hot dog",
1071+
"pizza",
1072+
"donut",
1073+
"cake",
1074+
"chair",
1075+
"couch",
1076+
"potted plant",
1077+
"bed",
1078+
"N/A",
1079+
"dining table",
1080+
"N/A",
1081+
"N/A",
1082+
"toilet",
1083+
"N/A",
1084+
"tv",
1085+
"laptop",
1086+
"mouse",
1087+
"remote",
1088+
"keyboard",
1089+
"cell phone",
1090+
"microwave",
1091+
"oven",
1092+
"toaster",
1093+
"sink",
1094+
"refrigerator",
1095+
"N/A",
1096+
"book",
1097+
"clock",
1098+
"vase",
1099+
"scissors",
1100+
"teddy bear",
1101+
"hair drier",
1102+
"toothbrush",
1103+
]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .faster_rcnn import *
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config
2+
from .. import resnet
3+
4+
5+
def resnet_fpn_backbone(
6+
backbone_name,
7+
weights,
8+
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
9+
trainable_layers=3,
10+
returned_layers=None,
11+
extra_blocks=None,
12+
):
13+
backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
14+
return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import warnings
2+
from typing import Any, Optional
3+
4+
from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
5+
from ...transforms.presets import CocoEval
6+
from .._api import Weights, WeightEntry
7+
from .._meta import _COCO_CATEGORIES
8+
from ..resnet import ResNet50Weights
9+
from .backbone_utils import resnet_fpn_backbone
10+
11+
12+
__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
13+
14+
15+
class FasterRCNNResNet50FPNWeights(Weights):
16+
Coco_RefV1 = WeightEntry(
17+
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
18+
transforms=CocoEval,
19+
meta={
20+
"categories": _COCO_CATEGORIES,
21+
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
22+
"map": 37.0,
23+
},
24+
)
25+
26+
27+
def fasterrcnn_resnet50_fpn(
28+
weights: Optional[FasterRCNNResNet50FPNWeights] = None,
29+
weights_backbone: Optional[ResNet50Weights] = None,
30+
progress: bool = True,
31+
num_classes: int = 91,
32+
trainable_backbone_layers: Optional[int] = None,
33+
**kwargs: Any,
34+
) -> FasterRCNN:
35+
if "pretrained" in kwargs:
36+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
37+
weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
38+
weights = FasterRCNNResNet50FPNWeights.verify(weights)
39+
if "pretrained_backbone" in kwargs:
40+
warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
41+
weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
42+
weights_backbone = ResNet50Weights.verify(weights_backbone)
43+
44+
if weights is not None:
45+
weights_backbone = None
46+
num_classes = len(weights.meta["categories"])
47+
48+
trainable_backbone_layers = _validate_trainable_layers(
49+
weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
50+
)
51+
52+
backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers)
53+
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
54+
55+
if weights is not None:
56+
model.load_state_dict(weights.state_dict(progress=progress))
57+
if weights == FasterRCNNResNet50FPNWeights.Coco_RefV1:
58+
overwrite_eps(model, 0.0)
59+
60+
return model
Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple
1+
from typing import Dict, Optional, Tuple
22

33
import torch
44
from torch import Tensor, nn
@@ -7,22 +7,19 @@
77
from ...transforms import functional as F
88

99

10-
__all__ = ["ConvertImageDtype", "ImageNetEval"]
10+
__all__ = ["CocoEval", "ImageNetEval"]
1111

1212

13-
# Allows handling of both PIL and Tensor images
14-
class ConvertImageDtype(nn.Module):
15-
def __init__(self, dtype: torch.dtype) -> None:
16-
super().__init__()
17-
self.dtype = dtype
18-
19-
def forward(self, img: Tensor) -> Tensor:
13+
class CocoEval(nn.Module):
14+
def forward(
15+
self, img: Tensor, target: Optional[Dict[str, Tensor]] = None
16+
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
2017
if not isinstance(img, Tensor):
2118
img = F.pil_to_tensor(img)
22-
return F.convert_image_dtype(img, self.dtype)
19+
return F.convert_image_dtype(img, torch.float), target
2320

2421

25-
class ImageNetEval:
22+
class ImageNetEval(nn.Module):
2623
def __init__(
2724
self,
2825
crop_size: int,
@@ -31,14 +28,14 @@ def __init__(
3128
std: Tuple[float, ...] = (0.229, 0.224, 0.225),
3229
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
3330
) -> None:
34-
self.transforms = T.Compose(
35-
[
36-
T.Resize(resize_size, interpolation=interpolation),
37-
T.CenterCrop(crop_size),
38-
ConvertImageDtype(dtype=torch.float),
39-
T.Normalize(mean=mean, std=std),
40-
]
41-
)
42-
43-
def __call__(self, img: Tensor) -> Tensor:
44-
return self.transforms(img)
31+
super().__init__()
32+
self._resize = T.Resize(resize_size, interpolation=interpolation)
33+
self._crop = T.CenterCrop(crop_size)
34+
self._normalize = T.Normalize(mean=mean, std=std)
35+
36+
def forward(self, img: Tensor) -> Tensor:
37+
img = self._crop(self._resize(img))
38+
if not isinstance(img, Tensor):
39+
img = F.pil_to_tensor(img)
40+
img = F.convert_image_dtype(img, torch.float)
41+
return self._normalize(img)

0 commit comments

Comments
 (0)