Skip to content

Commit 082f37e

Browse files
authored
Adding multiweight support for mobilenetv2 prototype (#4784)
1 parent 79b350e commit 082f37e

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

torchvision/prototype/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .vgg import *
55
from .efficientnet import *
66
from .mobilenetv3 import *
7+
from .mobilenetv2 import *
78
from .mnasnet import *
89
from . import detection
910
from . import quantization
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import warnings
2+
from functools import partial
3+
from typing import Any, Optional
4+
5+
from torchvision.transforms.functional import InterpolationMode
6+
7+
from ...models.mobilenetv2 import MobileNetV2
8+
from ..transforms.presets import ImageNetEval
9+
from ._api import Weights, WeightEntry
10+
from ._meta import _IMAGENET_CATEGORIES
11+
12+
13+
__all__ = ["MobileNetV2", "MobileNetV2Weights", "mobilenet_v2"]
14+
15+
16+
_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR}
17+
18+
19+
class MobileNetV2Weights(Weights):
20+
ImageNet1K_RefV1 = WeightEntry(
21+
url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
22+
transforms=partial(ImageNetEval, crop_size=224),
23+
meta={
24+
**_common_meta,
25+
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
26+
"acc@1": 71.878,
27+
"acc@5": 90.286,
28+
},
29+
)
30+
31+
32+
def mobilenet_v2(weights: Optional[MobileNetV2Weights] = None, progress: bool = True, **kwargs: Any) -> MobileNetV2:
33+
if "pretrained" in kwargs:
34+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
35+
weights = MobileNetV2Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
36+
weights = MobileNetV2Weights.verify(weights)
37+
38+
if weights is not None:
39+
kwargs["num_classes"] = len(weights.meta["categories"])
40+
41+
model = MobileNetV2(**kwargs)
42+
43+
if weights is not None:
44+
model.load_state_dict(weights.state_dict(progress=progress))
45+
46+
return model

0 commit comments

Comments
 (0)