From 9ead4b068cd34610bbf2f724ffe58d1e8ccc3ca7 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Mon, 16 May 2022 10:32:53 +0100 Subject: [PATCH 1/5] Add weight for mnasnet0_75 and mnasnet1_3 --- torchvision/models/mnasnet.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 6f403fb5e30..3ebcdf37867 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -235,8 +235,21 @@ class MNASNet0_5_Weights(WeightsEnum): class MNASNet0_75_Weights(WeightsEnum): - # If a default model is added here the corresponding changes need to be done in mnasnet0_75 - pass + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "MOCK" + "num_params": 3170208, + "metrics": { + # TODO: still mock need to update! + "acc@1": 71.180, + "acc@5": 90.494, + }, + }, + ) + DEFAULT = IMAGENET1K_V1 class MNASNet1_0_Weights(WeightsEnum): @@ -256,8 +269,21 @@ class MNASNet1_0_Weights(WeightsEnum): class MNASNet1_3_Weights(WeightsEnum): - # If a default model is added here the corresponding changes need to be done in mnasnet1_3 - pass + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "recipe": "MOCK" + "num_params": 6282256, + "metrics": { + # TODO: still mock need to update! + "acc@1": 76.506, + "acc@5": 93.522, + }, + }, + ) + DEFAULT = IMAGENET1K_V1 def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet: From 69f083b4c5c28eeb0d81a56f0d841f028af7c9e6 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Mon, 16 May 2022 09:43:02 +0000 Subject: [PATCH 2/5] Fix missing comma --- torchvision/models/mnasnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index 3ebcdf37867..b12ffa23d6e 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -240,7 +240,7 @@ class MNASNet0_75_Weights(WeightsEnum): transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, - "recipe": "MOCK" + "recipe": "MOCK", "num_params": 3170208, "metrics": { # TODO: still mock need to update! @@ -274,7 +274,7 @@ class MNASNet1_3_Weights(WeightsEnum): transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, - "recipe": "MOCK" + "recipe": "MOCK", "num_params": 6282256, "metrics": { # TODO: still mock need to update! From ba047b15e2924fa3dd7f6fffe08f33c8aafd61d2 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Mon, 16 May 2022 10:58:22 +0100 Subject: [PATCH 3/5] Add PR url as recipe, and update the metrics --- torchvision/models/mnasnet.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index b12ffa23d6e..c024af38b2e 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -240,12 +240,11 @@ class MNASNet0_75_Weights(WeightsEnum): transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, - "recipe": "MOCK", + "recipe": "https://github.com/pytorch/vision/pull/6019", "num_params": 3170208, "metrics": { - # TODO: still mock need to update! "acc@1": 71.180, - "acc@5": 90.494, + "acc@5": 90.496, }, }, ) @@ -274,10 +273,9 @@ class MNASNet1_3_Weights(WeightsEnum): transforms=partial(ImageClassification, crop_size=224, resize_size=232), meta={ **_COMMON_META, - "recipe": "MOCK", + "recipe": "https://github.com/pytorch/vision/pull/6019", "num_params": 6282256, "metrics": { - # TODO: still mock need to update! "acc@1": 76.506, "acc@5": 93.522, }, From 1a434f2b389deb69fe04e8c3464da1475201c267 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Mon, 16 May 2022 12:29:44 +0100 Subject: [PATCH 4/5] Add weights to legacy handler --- torchvision/models/mnasnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index c024af38b2e..bf74c7274ce 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -323,7 +323,7 @@ def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = return _mnasnet(0.5, weights, progress, **kwargs) -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1)) def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: """MNASNet with depth multiplier of 0.75 from `MnasNet: Platform-Aware Neural Architecture Search for Mobile @@ -375,7 +375,7 @@ def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = return _mnasnet(1.0, weights, progress, **kwargs) -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1)) def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet: """MNASNet with depth multiplier of 1.3 from `MnasNet: Platform-Aware Neural Architecture Search for Mobile From decf0e868f48f36385111479d81fcca5456b3fe3 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Mon, 16 May 2022 15:16:56 +0100 Subject: [PATCH 5/5] Update docs to specify there are weights available --- torchvision/models/mnasnet.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index bf74c7274ce..b1da02f4697 100644 --- a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -330,8 +330,10 @@ def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool `_ paper. Args: - weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): Currently - no pre-trained weights are available and by default no pre-trained + weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.MNASNet0_75_Weights` below for + more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True. @@ -382,8 +384,10 @@ def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = `_ paper. Args: - weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): Currently - no pre-trained weights are available and by default no pre-trained + weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.MNASNet1_3_Weights` below for + more details, and possible values. By default, no pre-trained weights are used. progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.