diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py
index 78f29b64ecc..56e91bb3d48 100644
--- a/test/test_prototype_models.py
+++ b/test/test_prototype_models.py
@@ -3,7 +3,7 @@
 import pytest
 import test_models as TM
 import torch
-from common_utils import cpu_and_gpu, run_on_env_var
+from common_utils import cpu_and_gpu, run_on_env_var, needs_cuda
 from torchvision.prototype import models
 from torchvision.prototype.models._api import WeightsEnum, Weights
 from torchvision.prototype.models._utils import handle_legacy_interface
@@ -75,10 +75,12 @@ def test_get_weight(name, weight):
     + TM.get_models_from_module(models.detection)
     + TM.get_models_from_module(models.quantization)
     + TM.get_models_from_module(models.segmentation)
-    + TM.get_models_from_module(models.video),
+    + TM.get_models_from_module(models.video)
+    + TM.get_models_from_module(models.optical_flow),
 )
 def test_naming_conventions(model_fn):
     weights_enum = _get_model_weights(model_fn)
+    print(weights_enum)
     assert weights_enum is not None
     assert len(weights_enum) == 0 or hasattr(weights_enum, "default")
 
@@ -149,13 +151,22 @@ def test_video_model(model_fn, dev):
     TM.test_video_model(model_fn, dev)
 
 
+@needs_cuda
+@pytest.mark.parametrize("model_builder", TM.get_models_from_module(models.optical_flow))
+@pytest.mark.parametrize("scripted", (False, True))
+@run_if_test_with_prototype
+def test_raft(model_builder, scripted):
+    TM.test_raft(model_builder, scripted)
+
+
 @pytest.mark.parametrize(
     "model_fn",
     TM.get_models_from_module(models)
     + TM.get_models_from_module(models.detection)
     + TM.get_models_from_module(models.quantization)
     + TM.get_models_from_module(models.segmentation)
-    + TM.get_models_from_module(models.video),
+    + TM.get_models_from_module(models.video)
+    + TM.get_models_from_module(models.optical_flow),
 )
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 @run_if_test_with_prototype
@@ -177,6 +188,9 @@ def test_old_vs_new_factory(model_fn, dev):
         "video": {
             "input_shape": (1, 3, 4, 112, 112),
         },
+        "optical_flow": {
+            "input_shape": (1, 3, 128, 128),
+        },
     }
     model_name = model_fn.__name__
     module_name = model_fn.__module__.split(".")[-2]
diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index ba1cc8499d8..f653895598f 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -585,7 +585,7 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
     """
 
     if pretrained:
-        raise ValueError("Pretrained weights aren't available yet")
+        raise ValueError("No checkpoint is available for raft_large")
 
     return _raft(
         # Feature encoder
@@ -631,7 +631,7 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
     """
 
     if pretrained:
-        raise ValueError("Pretrained weights aren't available yet")
+        raise ValueError("No checkpoint is available for raft_small")
 
     return _raft(
         # Feature encoder
diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py
index 12a4738e53c..bfa44ffa720 100644
--- a/torchvision/prototype/models/__init__.py
+++ b/torchvision/prototype/models/__init__.py
@@ -12,6 +12,7 @@
 from .vgg import *
 from .vision_transformer import *
 from . import detection
+from . import optical_flow
 from . import quantization
 from . import segmentation
 from . import video
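As context for the error-message change above, here is an illustrative snippet outside the patch itself, assuming the stable torchvision.models.optical_flow package exports these builders:

    # Only the wording of the ValueError changes: the stable builders still
    # reject pretrained=True because no checkpoint has been released yet.
    import pytest
    from torchvision.models.optical_flow import raft_large

    model = raft_large(pretrained=False)  # randomly initialised RAFT instance

    with pytest.raises(ValueError, match="No checkpoint is available"):
        raft_large(pretrained=True)
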
diff --git a/torchvision/prototype/models/optical_flow/__init__.py b/torchvision/prototype/models/optical_flow/__init__.py
new file mode 100644
index 00000000000..9b78f70b768
--- /dev/null
+++ b/torchvision/prototype/models/optical_flow/__init__.py
@@ -0,0 +1 @@
+from .raft import RAFT, raft_large, raft_small, Raft_Large_Weights, Raft_Small_Weights
diff --git a/torchvision/prototype/models/optical_flow/raft.py b/torchvision/prototype/models/optical_flow/raft.py
new file mode 100644
index 00000000000..4dad4b3b6b1
--- /dev/null
+++ b/torchvision/prototype/models/optical_flow/raft.py
@@ -0,0 +1,166 @@
+from typing import Optional
+
+from torch.nn.modules.batchnorm import BatchNorm2d
+from torch.nn.modules.instancenorm import InstanceNorm2d
+from torchvision.models.optical_flow import RAFT
+from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock
+
+# from torchvision.prototype.transforms import RaftEval
+
+from .._api import WeightsEnum
+
+# from .._api import Weights
+from .._utils import handle_legacy_interface
+
+
+__all__ = (
+    "RAFT",
+    "raft_large",
+    "raft_small",
+    "Raft_Large_Weights",
+    "Raft_Small_Weights",
+)
+
+
+class Raft_Large_Weights(WeightsEnum):
+    pass
+    # C_T_V1 = Weights(
+    #     # Chairs + Things
+    #     url="",
+    #     transforms=RaftEval,
+    #     meta={
+    #         "recipe": "",
+    #         "epe": -1234,
+    #     },
+    # )
+
+    # C_T_SKHT_V1 = Weights(
+    #     # Chairs + Things + Sintel fine-tuning, i.e.:
+    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
+    #     # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
+    #     url="",
+    #     transforms=RaftEval,
+    #     meta={
+    #         "recipe": "",
+    #         "epe": -1234,
+    #     },
+    # )
+
+    # C_T_SKHT_K_V1 = Weights(
+    #     # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning i.e.:
+    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
+    #     # Same as CT_SKHT with extra fine-tuning on Kitti
+    #     # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
+    #     url="",
+    #     transforms=RaftEval,
+    #     meta={
+    #         "recipe": "",
+    #         "epe": -1234,
+    #     },
+    # )
+
+    # default = C_T_V1
+
+
+class Raft_Small_Weights(WeightsEnum):
+    pass
+    # C_T_V1 = Weights(
+    #     url="",  # TODO
+    #     transforms=RaftEval,
+    #     meta={
+    #         "recipe": "",
+    #         "epe": -1234,
+    #     },
+    # )
+    # default = C_T_V1
+
+
+@handle_legacy_interface(weights=("pretrained", None))
+def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
+    """RAFT model from
+    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
+
+    Args:
+        weights (Raft_Large_Weights, optional): TODO not implemented yet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
+            to override any default.
+
+    Returns:
+        nn.Module: The model.
+    """
+
+    weights = Raft_Large_Weights.verify(weights)
+
+    return _raft(
+        # Feature encoder
+        feature_encoder_layers=(64, 64, 96, 128, 256),
+        feature_encoder_block=ResidualBlock,
+        feature_encoder_norm_layer=InstanceNorm2d,
+        # Context encoder
+        context_encoder_layers=(64, 64, 96, 128, 256),
+        context_encoder_block=ResidualBlock,
+        context_encoder_norm_layer=BatchNorm2d,
+        # Correlation block
+        corr_block_num_levels=4,
+        corr_block_radius=4,
+        # Motion encoder
+        motion_encoder_corr_layers=(256, 192),
+        motion_encoder_flow_layers=(128, 64),
+        motion_encoder_out_channels=128,
+        # Recurrent block
+        recurrent_block_hidden_state_size=128,
+        recurrent_block_kernel_size=((1, 5), (5, 1)),
+        recurrent_block_padding=((0, 2), (2, 0)),
+        # Flow head
+        flow_head_hidden_size=256,
+        # Mask predictor
+        use_mask_predictor=True,
+        **kwargs,
+    )
+
+
+@handle_legacy_interface(weights=("pretrained", None))
+def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
+    """RAFT "small" model from
+    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.
+
+    Args:
+        weights (Raft_Small_Weights, optional): TODO not implemented yet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
+            to override any default.
+
+    Returns:
+        nn.Module: The model.
+
+    """
+
+    weights = Raft_Small_Weights.verify(weights)
+
+    return _raft(
+        # Feature encoder
+        feature_encoder_layers=(32, 32, 64, 96, 128),
+        feature_encoder_block=BottleneckBlock,
+        feature_encoder_norm_layer=InstanceNorm2d,
+        # Context encoder
+        context_encoder_layers=(32, 32, 64, 96, 160),
+        context_encoder_block=BottleneckBlock,
+        context_encoder_norm_layer=None,
+        # Correlation block
+        corr_block_num_levels=4,
+        corr_block_radius=3,
+        # Motion encoder
+        motion_encoder_corr_layers=(96,),
+        motion_encoder_flow_layers=(64, 32),
+        motion_encoder_out_channels=82,
+        # Recurrent block
+        recurrent_block_hidden_state_size=96,
+        recurrent_block_kernel_size=(3,),
+        recurrent_block_padding=(1,),
+        # Flow head
+        flow_head_hidden_size=128,
+        # Mask predictor
+        use_mask_predictor=False,
+        **kwargs,
+    )
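A minimal usage sketch for the new prototype builders defined above; since both weights enums are still empty placeholders, weights=None is the only meaningful value, and the forward signature assumed here is that of torchvision.models.optical_flow.RAFT, which returns the list of iteratively refined flow predictions:

    import torch
    from torchvision.prototype.models.optical_flow import raft_small

    model = raft_small(weights=None).eval()  # the weights enum has no entries yet

    # Two frames; RAFT needs spatial dimensions divisible by 8.
    img1 = torch.rand(1, 3, 128, 128)
    img2 = torch.rand(1, 3, 128, 128)
    with torch.no_grad():
        flow_predictions = model(img1, img2)
    final_flow = flow_predictions[-1]  # (1, 2, 128, 128) flow field
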
+ """ + + weights = Raft_Large_Weights.verify(weights) + + return _raft( + # Feature encoder + feature_encoder_layers=(64, 64, 96, 128, 256), + feature_encoder_block=ResidualBlock, + feature_encoder_norm_layer=InstanceNorm2d, + # Context encoder + context_encoder_layers=(64, 64, 96, 128, 256), + context_encoder_block=ResidualBlock, + context_encoder_norm_layer=BatchNorm2d, + # Correlation block + corr_block_num_levels=4, + corr_block_radius=4, + # Motion encoder + motion_encoder_corr_layers=(256, 192), + motion_encoder_flow_layers=(128, 64), + motion_encoder_out_channels=128, + # Recurrent block + recurrent_block_hidden_state_size=128, + recurrent_block_kernel_size=((1, 5), (5, 1)), + recurrent_block_padding=((0, 2), (2, 0)), + # Flow head + flow_head_hidden_size=256, + # Mask predictor + use_mask_predictor=True, + **kwargs, + ) + + +@handle_legacy_interface(weights=("pretrained", None)) +def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs): + """RAFT "small" model from + `RAFT: Recurrent All Pairs Field Transforms for Optical Flow `_. + + Args: + weights(Raft_Small_weights, optinal): TODO not implemented yet + progress (bool): If True, displays a progress bar of the download to stderr + kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class + to override any default. + + Returns: + nn.Module: The model. + + """ + + weights = Raft_Small_Weights.verify(weights) + + return _raft( + # Feature encoder + feature_encoder_layers=(32, 32, 64, 96, 128), + feature_encoder_block=BottleneckBlock, + feature_encoder_norm_layer=InstanceNorm2d, + # Context encoder + context_encoder_layers=(32, 32, 64, 96, 160), + context_encoder_block=BottleneckBlock, + context_encoder_norm_layer=None, + # Correlation block + corr_block_num_levels=4, + corr_block_radius=3, + # Motion encoder + motion_encoder_corr_layers=(96,), + motion_encoder_flow_layers=(64, 32), + motion_encoder_out_channels=82, + # Recurrent block + recurrent_block_hidden_state_size=96, + recurrent_block_kernel_size=(3,), + recurrent_block_padding=(1,), + # Flow head + flow_head_hidden_size=128, + # Mask predictor + use_mask_predictor=False, + **kwargs, + ) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index c91542933b8..56cca7b0402 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -3,4 +3,4 @@ from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop from ._misc import Identity, Normalize -from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval +from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py index 3b9d733d8df..d7c4ddb4684 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/prototype/transforms/_presets.py @@ -6,7 +6,7 @@ from ...transforms import functional as F, InterpolationMode -__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval"] +__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval", "RaftEval"] class CocoEval(nn.Module): @@ -97,3 +97,38 @@ def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor, target = F.pil_to_tensor(target) target = target.squeeze(0).to(torch.int64) return img, target + + +class RaftEval(nn.Module): + def forward( + self, img1: Tensor, img2: Tensor, flow: 
+    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+
+        img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask)
+
+        img1 = F.convert_image_dtype(img1, torch.float32)
+        img2 = F.convert_image_dtype(img2, torch.float32)
+
+        # map [0, 1] into [-1, 1]
+        img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+        img2 = F.normalize(img2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+        img1 = img1.contiguous()
+        img2 = img2.contiguous()
+
+        return img1, img2, flow, valid_flow_mask
+
+    def _pil_or_numpy_to_tensor(
+        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
+    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+        if not isinstance(img1, Tensor):
+            img1 = F.pil_to_tensor(img1)
+        if not isinstance(img2, Tensor):
+            img2 = F.pil_to_tensor(img2)
+
+        if flow is not None and not isinstance(flow, Tensor):
+            flow = torch.from_numpy(flow)
+        if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor):
+            valid_flow_mask = torch.from_numpy(valid_flow_mask)
+
+        return img1, img2, flow, valid_flow_mask
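A minimal sketch of how the RaftEval preset above is meant to be applied to an optical-flow sample; the PIL frames and numpy flow/mask inputs are illustrative, and the preset only converts them to tensors, rescales the images to [-1, 1], and passes the flow targets through unchanged:

    import numpy as np
    from PIL import Image
    from torchvision.prototype.transforms import RaftEval

    preset = RaftEval()

    # Illustrative inputs: two frames plus a ground-truth flow and validity mask.
    img1 = Image.new("RGB", (128, 128))
    img2 = Image.new("RGB", (128, 128))
    flow = np.zeros((2, 128, 128), dtype=np.float32)
    valid = np.ones((128, 128), dtype=bool)

    img1, img2, flow, valid = preset(img1, img2, flow, valid)
    # img1/img2 are float32 tensors in [-1, 1]; flow and valid are now torch tensors.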