From a1d5c243500f8931b1999c5f070bb8ab697a46ce Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 11:39:49 +0000 Subject: [PATCH 1/9] Add raft builders and presets in prototypes --- torchvision/prototype/models/__init__.py | 1 + .../prototype/models/optical_flow/__init__.py | 1 + .../prototype/models/optical_flow/raft.py | 162 ++++++++++++++++++ torchvision/prototype/transforms/__init__.py | 2 +- torchvision/prototype/transforms/_presets.py | 34 ++++ 5 files changed, 199 insertions(+), 1 deletion(-) create mode 100644 torchvision/prototype/models/optical_flow/__init__.py create mode 100644 torchvision/prototype/models/optical_flow/raft.py diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index 12a4738e53c..bfa44ffa720 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -12,6 +12,7 @@ from .vgg import * from .vision_transformer import * from . import detection +from . import optical_flow from . import quantization from . import segmentation from . import video diff --git a/torchvision/prototype/models/optical_flow/__init__.py b/torchvision/prototype/models/optical_flow/__init__.py new file mode 100644 index 00000000000..9dd32f25dec --- /dev/null +++ b/torchvision/prototype/models/optical_flow/__init__.py @@ -0,0 +1 @@ +from .raft import RAFT, raft_large, raft_small diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py new file mode 100644 index 00000000000..ca768981e35 --- /dev/null +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -0,0 +1,162 @@ +from typing import Optional + +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.instancenorm import InstanceNorm2d +from torchvision.models.optical_flow import RAFT, BottleneckBlock, ResidualBlock +from torchvision.models.optical_flow.raft import _raft +from torchvision.prototype.transforms import RaftEval + +from .._api import WeightsEnum, Weights + + +__all__ = ( + "RAFT", + "raft_large", + "raft_small", +) + + +class Raft_Large_Weights(WeightsEnum): + C_T = Weights( + # Chairs + Things + url="", + transforms=RaftEval, + meta={ + "recipe": "", + "epe": -1234, + }, + ) + + C_T_SKHT = Weights( + # Chairs + Things + Sintel fine-tuning, i.e.: + # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel + url="", + transforms=RaftEval, + meta={ + "recipe": "", + "epe": -1234, + }, + ) + + C_T_SKHT_K = Weights( + # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: + # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti + # Same as CT_SKHT with extra fine-tuning on Kitti + # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti + url="", + transforms=RaftEval, + meta={ + "recipe": "", + "epe": -1234, + }, + ) + + default = C_T + + +class Raft_Small_Weights(WeightsEnum): + C_T = Weights( + url="", # TODO + transforms=RaftEval, + meta={ + "recipe": "", + "epe": -1234, + }, + ) + default = C_T + + +def raft_large(weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs): + """RAFT model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. + + Args: + weights(Raft_Large_weights, optinal): TODO not implemented yet + progress (bool): If True, displays a progress bar of the download to stderr + kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class + to override any default. + + Returns: + nn.Module: The model. + """ + + if weights is not None: + raise ValueError("Pretrained weights aren't available yet") + + weights = Raft_Large_Weights.verify(weights) + + return _raft( + # Feature encoder + feature_encoder_layers=(64, 64, 96, 128, 256), + feature_encoder_block=ResidualBlock, + feature_encoder_norm_layer=InstanceNorm2d, + # Context encoder + context_encoder_layers=(64, 64, 96, 128, 256), + context_encoder_block=ResidualBlock, + context_encoder_norm_layer=BatchNorm2d, + # Correlation block + corr_block_num_levels=4, + corr_block_radius=4, + # Motion encoder + motion_encoder_corr_layers=(256, 192), + motion_encoder_flow_layers=(128, 64), + motion_encoder_out_channels=128, + # Recurrent block + recurrent_block_hidden_state_size=128, + recurrent_block_kernel_size=((1, 5), (5, 1)), + recurrent_block_padding=((0, 2), (2, 0)), + # Flow head + flow_head_hidden_size=256, + # Mask predictor + use_mask_predictor=True, + **kwargs, + ) + + +def raft_small(weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs): + """RAFT "small" model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. + + Args: + weights(Raft_Small_weights, optinal): TODO not implemented yet + progress (bool): If True, displays a progress bar of the download to stderr + kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class + to override any default. + + Returns: + nn.Module: The model. + + """ + + if weights is not None: + raise ValueError("Pretrained weights aren't available yet") + + weights = Raft_Small_Weights.verify(weights) + + return _raft( + # Feature encoder + feature_encoder_layers=(32, 32, 64, 96, 128), + feature_encoder_block=BottleneckBlock, + feature_encoder_norm_layer=InstanceNorm2d, + # Context encoder + context_encoder_layers=(32, 32, 64, 96, 160), + context_encoder_block=BottleneckBlock, + context_encoder_norm_layer=None, + # Correlation block + corr_block_num_levels=4, + corr_block_radius=3, + # Motion encoder + motion_encoder_corr_layers=(96,), + motion_encoder_flow_layers=(64, 32), + motion_encoder_out_channels=82, + # Recurrent block + recurrent_block_hidden_state_size=96, + recurrent_block_kernel_size=(3,), + recurrent_block_padding=(1,), + # Flow head + flow_head_hidden_size=128, + # Mask predictor + use_mask_predictor=False, + **kwargs, + ) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index c91542933b8..56cca7b0402 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -3,4 +3,4 @@ from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop from ._misc import Identity, Normalize -from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval +from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py index 3b9d733d8df..7edafaf7b54 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/prototype/transforms/_presets.py @@ -1,5 +1,6 @@ from typing import Dict, Optional, Tuple +import numpy as np import torch from torch import Tensor, nn @@ -97,3 +98,36 @@ def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, target = F.pil_to_tensor(target) target = target.squeeze(0).to(torch.int64) return img, target + + +class RaftEval(nn.Module): + def forward( + self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor] + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + + img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask) + + img1 = F.convert_image_dtype(img1, torch.float32) + img2 = F.convert_image_dtype(img2, torch.float32) + + # map [0, 1] into [-1, 1] + img1 = F.normalize(img1, mean=0.5, std=0.5) + img2 = F.normalize(img2, mean=0.5, std=0.5) + + img1 = img1.contiguous() + img2 = img2.contiguous() + + return img1, img2, flow, valid_flow_mask + + def _pil_or_numpy_to_tensor(self, img1, img2, flow, valid_flow_mask): + if not isinstance(img1, Tensor): + img1 = F.pil_to_tensor(img1) + if not isinstance(img2, Tensor): + img2 = F.pil_to_tensor(img2) + + if flow is not None and not isinstance(flow, np.ndarray): + flow = torch.from_numpy(flow) + if valid_flow_mask is not None and not isinstance(flow, np.ndarray): + valid_flow_mask = torch.from_numpy(valid_flow_mask) + + return img1, img2, flow, valid_flow_mask From b436fc0a37ef37c7f7cedf3fc5deaa1ce1ced053 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 13:09:09 +0000 Subject: [PATCH 2/9] Switch import --- torchvision/prototype/models/optical_flow/raft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index ca768981e35..14b23eac8ba 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -2,8 +2,8 @@ from torch.nn.modules.batchnorm import BatchNorm2d from torch.nn.modules.instancenorm import InstanceNorm2d -from torchvision.models.optical_flow import RAFT, BottleneckBlock, ResidualBlock -from torchvision.models.optical_flow.raft import _raft +from torchvision.models.optical_flow import RAFT +from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock from torchvision.prototype.transforms import RaftEval from .._api import WeightsEnum, Weights From 1f58b779cc23f644b581db8caaead5b94d2eda21 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 13:30:42 +0000 Subject: [PATCH 3/9] Update torchvision/prototype/transforms/_presets.py Co-authored-by: Vasilis Vryniotis --- torchvision/prototype/transforms/_presets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py index 7edafaf7b54..b0502a3a41d 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/prototype/transforms/_presets.py @@ -119,7 +119,7 @@ def forward( return img1, img2, flow, valid_flow_mask - def _pil_or_numpy_to_tensor(self, img1, img2, flow, valid_flow_mask): + def _pil_or_numpy_to_tensor(self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor])-> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: if not isinstance(img1, Tensor): img1 = F.pil_to_tensor(img1) if not isinstance(img2, Tensor): From c8d65b68bd3063b7517c17fad70f38438e21a2e3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 13:47:50 +0000 Subject: [PATCH 4/9] Address comments --- torchvision/models/optical_flow/raft.py | 4 +- .../prototype/models/optical_flow/raft.py | 107 +++++++++--------- torchvision/prototype/transforms/_presets.py | 13 ++- 3 files changed, 65 insertions(+), 59 deletions(-) diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index ba1cc8499d8..f653895598f 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -585,7 +585,7 @@ def raft_large(*, pretrained=False, progress=True, **kwargs): """ if pretrained: - raise ValueError("Pretrained weights aren't available yet") + raise ValueError("No checkpoint is available for raft_large") return _raft( # Feature encoder @@ -631,7 +631,7 @@ def raft_small(*, pretrained=False, progress=True, **kwargs): """ if pretrained: - raise ValueError("Pretrained weights aren't available yet") + raise ValueError("No checkpoint is available for raft_small") return _raft( # Feature encoder diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index 14b23eac8ba..1f13bdc914b 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -7,6 +7,7 @@ from torchvision.prototype.transforms import RaftEval from .._api import WeightsEnum, Weights +from .._utils import handle_legacy_interface __all__ = ( @@ -17,57 +18,60 @@ class Raft_Large_Weights(WeightsEnum): - C_T = Weights( - # Chairs + Things - url="", - transforms=RaftEval, - meta={ - "recipe": "", - "epe": -1234, - }, - ) - - C_T_SKHT = Weights( - # Chairs + Things + Sintel fine-tuning, i.e.: - # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) - # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel - url="", - transforms=RaftEval, - meta={ - "recipe": "", - "epe": -1234, - }, - ) - - C_T_SKHT_K = Weights( - # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: - # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti - # Same as CT_SKHT with extra fine-tuning on Kitti - # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti - url="", - transforms=RaftEval, - meta={ - "recipe": "", - "epe": -1234, - }, - ) - - default = C_T + pass + # C_T_V1 = Weights( + # # Chairs + Things + # url="", + # transforms=RaftEval, + # meta={ + # "recipe": "", + # "epe": -1234, + # }, + # ) + + # C_T_SKHT_V1 = Weights( + # # Chairs + Things + Sintel fine-tuning, i.e.: + # # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + # # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel + # url="", + # transforms=RaftEval, + # meta={ + # "recipe": "", + # "epe": -1234, + # }, + # ) + + # C_T_SKHT_K_V1 = Weights( + # # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.: + # # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti + # # Same as CT_SKHT with extra fine-tuning on Kitti + # # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti + # url="", + # transforms=RaftEval, + # meta={ + # "recipe": "", + # "epe": -1234, + # }, + # ) + + # default = C_T_V1 class Raft_Small_Weights(WeightsEnum): - C_T = Weights( - url="", # TODO - transforms=RaftEval, - meta={ - "recipe": "", - "epe": -1234, - }, - ) - default = C_T - - -def raft_large(weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs): + pass + # C_T_V1 = Weights( + # url="", # TODO + # transforms=RaftEval, + # meta={ + # "recipe": "", + # "epe": -1234, + # }, + # ) + # default = C_T_V1 + + +@handle_legacy_interface(weights=("pretrained", None)) +def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs): """RAFT model from `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. @@ -82,7 +86,7 @@ def raft_large(weights: Optional[Raft_Large_Weights] = None, progress=True, **kw """ if weights is not None: - raise ValueError("Pretrained weights aren't available yet") + raise ValueError("No checkpoint is available for raft_large") weights = Raft_Large_Weights.verify(weights) @@ -114,7 +118,8 @@ def raft_large(weights: Optional[Raft_Large_Weights] = None, progress=True, **kw ) -def raft_small(weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs): +@handle_legacy_interface(weights=("pretrained", None)) +def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs): """RAFT "small" model from `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. @@ -130,7 +135,7 @@ def raft_small(weights: Optional[Raft_Small_Weights] = None, progress=True, **kw """ if weights is not None: - raise ValueError("Pretrained weights aren't available yet") + raise ValueError("No checkpoint is available for raft_small") weights = Raft_Small_Weights.verify(weights) diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py index b0502a3a41d..2f920bf0e72 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/prototype/transforms/_presets.py @@ -1,6 +1,5 @@ from typing import Dict, Optional, Tuple -import numpy as np import torch from torch import Tensor, nn @@ -111,23 +110,25 @@ def forward( img2 = F.convert_image_dtype(img2, torch.float32) # map [0, 1] into [-1, 1] - img1 = F.normalize(img1, mean=0.5, std=0.5) - img2 = F.normalize(img2, mean=0.5, std=0.5) + img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + img2 = F.normalize(img2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) img1 = img1.contiguous() img2 = img2.contiguous() return img1, img2, flow, valid_flow_mask - def _pil_or_numpy_to_tensor(self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor])-> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + def _pil_or_numpy_to_tensor( + self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor] + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: if not isinstance(img1, Tensor): img1 = F.pil_to_tensor(img1) if not isinstance(img2, Tensor): img2 = F.pil_to_tensor(img2) - if flow is not None and not isinstance(flow, np.ndarray): + if flow is not None and not isinstance(flow, Tensor): flow = torch.from_numpy(flow) - if valid_flow_mask is not None and not isinstance(flow, np.ndarray): + if valid_flow_mask is not None and not isinstance(flow, Tensor): valid_flow_mask = torch.from_numpy(valid_flow_mask) return img1, img2, flow, valid_flow_mask From 17522f4fb268e4e2c400a403f055197b75a1dd60 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 13:50:07 +0000 Subject: [PATCH 5/9] Comment out unsued imports --- torchvision/prototype/models/optical_flow/raft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index 1f13bdc914b..24141d04368 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -4,9 +4,9 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.models.optical_flow import RAFT from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock -from torchvision.prototype.transforms import RaftEval +# from torchvision.prototype.transforms import RaftEval -from .._api import WeightsEnum, Weights +# from .._api import WeightsEnum, Weights from .._utils import handle_legacy_interface From 1dc06ba58877719007d8b10931cc746d091479e9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 13:50:40 +0000 Subject: [PATCH 6/9] Comment out unsued imports --- torchvision/prototype/models/optical_flow/raft.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index 24141d04368..a459af5b73f 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -6,7 +6,8 @@ from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock # from torchvision.prototype.transforms import RaftEval -# from .._api import WeightsEnum, Weights +from .._api import WeightsEnum +# from .._api import Weights from .._utils import handle_legacy_interface From 48674ba034233b335c4ab5dea52d2ce446824084 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 13:51:23 +0000 Subject: [PATCH 7/9] typo flow -> valid_flow_mask --- torchvision/prototype/transforms/_presets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py index 2f920bf0e72..f18395f4063 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/prototype/transforms/_presets.py @@ -128,7 +128,7 @@ def _pil_or_numpy_to_tensor( if flow is not None and not isinstance(flow, Tensor): flow = torch.from_numpy(flow) - if valid_flow_mask is not None and not isinstance(flow, Tensor): + if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor): valid_flow_mask = torch.from_numpy(valid_flow_mask) return img1, img2, flow, valid_flow_mask From 68d1f412edef611a813e56a8cfd13209d0c615a1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 7 Dec 2021 19:55:25 +0000 Subject: [PATCH 8/9] Add tests --- test/test_prototype_models.py | 20 ++++++++++++++++--- .../prototype/models/optical_flow/__init__.py | 2 +- .../prototype/models/optical_flow/raft.py | 10 ++++------ torchvision/prototype/transforms/_presets.py | 2 +- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 3286785f60a..ae577f31b3e 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -3,7 +3,7 @@ import pytest import test_models as TM import torch -from common_utils import cpu_and_gpu, run_on_env_var +from common_utils import cpu_and_gpu, run_on_env_var, needs_cuda from torchvision.prototype import models from torchvision.prototype.models._api import WeightsEnum, Weights from torchvision.prototype.models._utils import handle_legacy_interface @@ -75,10 +75,12 @@ def test_get_weight(name, weight): + TM.get_models_from_module(models.detection) + TM.get_models_from_module(models.quantization) + TM.get_models_from_module(models.segmentation) - + TM.get_models_from_module(models.video), + + TM.get_models_from_module(models.video) + + TM.get_models_from_module(models.optical_flow), ) def test_naming_conventions(model_fn): weights_enum = _get_model_weights(model_fn) + print(weights_enum) assert weights_enum is not None assert len(weights_enum) == 0 or hasattr(weights_enum, "default") @@ -117,13 +119,22 @@ def test_video_model(model_fn, dev): TM.test_video_model(model_fn, dev) +@needs_cuda +@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small)) +@pytest.mark.parametrize("scripted", (False, True)) +@run_if_test_with_prototype +def test_raft(model_builder, scripted): + TM.test_raft(model_builder, scripted) + + @pytest.mark.parametrize( "model_fn", TM.get_models_from_module(models) + TM.get_models_from_module(models.detection) + TM.get_models_from_module(models.quantization) + TM.get_models_from_module(models.segmentation) - + TM.get_models_from_module(models.video), + + TM.get_models_from_module(models.video) + + TM.get_models_from_module(models.optical_flow), ) @pytest.mark.parametrize("dev", cpu_and_gpu()) @run_if_test_with_prototype @@ -145,6 +156,9 @@ def test_old_vs_new_factory(model_fn, dev): "video": { "input_shape": (1, 3, 4, 112, 112), }, + "optical_flow": { + "input_shape": (1, 3, 128, 128), + }, } model_name = model_fn.__name__ module_name = model_fn.__module__.split(".")[-2] diff --git a/torchvision/prototype/models/optical_flow/__init__.py b/torchvision/prototype/models/optical_flow/__init__.py index 9dd32f25dec..9b78f70b768 100644 --- a/torchvision/prototype/models/optical_flow/__init__.py +++ b/torchvision/prototype/models/optical_flow/__init__.py @@ -1 +1 @@ -from .raft import RAFT, raft_large, raft_small +from .raft import RAFT, raft_large, raft_small, Raft_Large_Weights, Raft_Small_Weights diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py index a459af5b73f..4dad4b3b6b1 100644 --- a/torchvision/prototype/models/optical_flow/raft.py +++ b/torchvision/prototype/models/optical_flow/raft.py @@ -4,9 +4,11 @@ from torch.nn.modules.instancenorm import InstanceNorm2d from torchvision.models.optical_flow import RAFT from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock + # from torchvision.prototype.transforms import RaftEval from .._api import WeightsEnum + # from .._api import Weights from .._utils import handle_legacy_interface @@ -15,6 +17,8 @@ "RAFT", "raft_large", "raft_small", + "Raft_Large_Weights", + "Raft_Small_Weights", ) @@ -86,9 +90,6 @@ def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, * nn.Module: The model. """ - if weights is not None: - raise ValueError("No checkpoint is available for raft_large") - weights = Raft_Large_Weights.verify(weights) return _raft( @@ -135,9 +136,6 @@ def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, * """ - if weights is not None: - raise ValueError("No checkpoint is available for raft_small") - weights = Raft_Small_Weights.verify(weights) return _raft( diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py index f18395f4063..d7c4ddb4684 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/prototype/transforms/_presets.py @@ -6,7 +6,7 @@ from ...transforms import functional as F, InterpolationMode -__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval"] +__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval", "RaftEval"] class CocoEval(nn.Module): From 814c20c918bf9ba0dff91d80188e3ff1b1e2d776 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 8 Dec 2021 09:06:16 +0000 Subject: [PATCH 9/9] Update test/test_prototype_models.py Co-authored-by: Vasilis Vryniotis --- test/test_prototype_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index ae577f31b3e..baf26a1faaf 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -120,7 +120,7 @@ def test_video_model(model_fn, dev): @needs_cuda -@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small)) +@pytest.mark.parametrize("model_builder", TM.get_models_from_module(models.optical_flow)) @pytest.mark.parametrize("scripted", (False, True)) @run_if_test_with_prototype def test_raft(model_builder, scripted):