Skip to content

Commit c88423b

Browse files
authored
Multi-pretrained weight support - Quantized ResNet50 (#4627)
* Fixing minor issue on typing. * Sample implementation for quantized resnet50.
1 parent 6b0097b commit c88423b

File tree

4 files changed

+87
-1
lines changed

4 files changed

+87
-1
lines changed

torchvision/models/quantization/resnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def fuse_model(self) -> None:
110110

111111
def _resnet(
112112
arch: str,
113-
block: Type[Union[BasicBlock, Bottleneck]],
113+
block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
114114
layers: List[int],
115115
pretrained: bool,
116116
progress: bool,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .resnet import *
22
from . import detection
3+
from . import quantization
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .resnet import *
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import warnings
2+
from functools import partial
3+
from typing import Any, List, Optional, Type, Union
4+
5+
from ....models.quantization.resnet import (
6+
QuantizableBasicBlock,
7+
QuantizableBottleneck,
8+
QuantizableResNet,
9+
_replace_relu,
10+
quantize_model,
11+
)
12+
from ...transforms.presets import ImageNetEval
13+
from .._api import Weights, WeightEntry
14+
from .._meta import _IMAGENET_CATEGORIES
15+
from ..resnet import ResNet50Weights
16+
17+
18+
__all__ = ["QuantizableResNet", "QuantizedResNet50Weights", "resnet50"]
19+
20+
21+
def _resnet(
22+
block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
23+
layers: List[int],
24+
weights: Optional[Weights],
25+
progress: bool,
26+
quantize: bool,
27+
**kwargs: Any,
28+
) -> QuantizableResNet:
29+
if weights is not None:
30+
kwargs["num_classes"] = len(weights.meta["categories"])
31+
if "backend" in weights.meta:
32+
kwargs["backend"] = weights.meta["backend"]
33+
backend = kwargs.pop("backend", "fbgemm")
34+
35+
model = QuantizableResNet(block, layers, **kwargs)
36+
_replace_relu(model)
37+
if quantize:
38+
quantize_model(model, backend)
39+
40+
if weights is not None:
41+
model.load_state_dict(weights.state_dict(progress=progress))
42+
43+
return model
44+
45+
46+
_common_meta = {
47+
"size": (224, 224),
48+
"categories": _IMAGENET_CATEGORIES,
49+
"backend": "fbgemm",
50+
}
51+
52+
53+
class QuantizedResNet50Weights(Weights):
54+
ImageNet1K_FBGEMM_RefV1 = WeightEntry(
55+
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
56+
transforms=partial(ImageNetEval, crop_size=224),
57+
meta={
58+
**_common_meta,
59+
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#quantized",
60+
"acc@1": 75.920,
61+
"acc@5": 92.814,
62+
},
63+
)
64+
65+
66+
def resnet50(
67+
weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None,
68+
progress: bool = True,
69+
quantize: bool = False,
70+
**kwargs: Any,
71+
) -> QuantizableResNet:
72+
if "pretrained" in kwargs:
73+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
74+
if kwargs.pop("pretrained"):
75+
weights = QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet50Weights.ImageNet1K_RefV1
76+
else:
77+
weights = None
78+
79+
if quantize:
80+
weights = QuantizedResNet50Weights.verify(weights)
81+
else:
82+
weights = ResNet50Weights.verify(weights)
83+
84+
return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)

0 commit comments

Comments
 (0)