Add raft builders and presets in prototypes #5043

Merged · 14 commits · Dec 8, 2021
20 changes: 17 additions & 3 deletions test/test_prototype_models.py
@@ -3,7 +3,7 @@
import pytest
import test_models as TM
import torch
from common_utils import cpu_and_gpu, run_on_env_var
from common_utils import cpu_and_gpu, run_on_env_var, needs_cuda
from torchvision.prototype import models
from torchvision.prototype.models._api import WeightsEnum, Weights
from torchvision.prototype.models._utils import handle_legacy_interface
@@ -75,10 +75,12 @@ def test_get_weight(name, weight):
    + TM.get_models_from_module(models.detection)
    + TM.get_models_from_module(models.quantization)
    + TM.get_models_from_module(models.segmentation)
    + TM.get_models_from_module(models.video),
    + TM.get_models_from_module(models.video)
    + TM.get_models_from_module(models.optical_flow),
)
def test_naming_conventions(model_fn):
    weights_enum = _get_model_weights(model_fn)
    print(weights_enum)
    assert weights_enum is not None
    assert len(weights_enum) == 0 or hasattr(weights_enum, "default")

@@ -149,13 +151,22 @@ def test_video_model(model_fn, dev):
    TM.test_video_model(model_fn, dev)


@needs_cuda
@pytest.mark.parametrize("model_builder", TM.get_models_from_module(models.optical_flow))
@pytest.mark.parametrize("scripted", (False, True))
@run_if_test_with_prototype
def test_raft(model_builder, scripted):
    TM.test_raft(model_builder, scripted)


@pytest.mark.parametrize(
    "model_fn",
    TM.get_models_from_module(models)
    + TM.get_models_from_module(models.detection)
    + TM.get_models_from_module(models.quantization)
    + TM.get_models_from_module(models.segmentation)
    + TM.get_models_from_module(models.video),
    + TM.get_models_from_module(models.video)
    + TM.get_models_from_module(models.optical_flow),
)
@pytest.mark.parametrize("dev", cpu_and_gpu())
@run_if_test_with_prototype
@@ -177,6 +188,9 @@ def test_old_vs_new_factory(model_fn, dev):
        "video": {
            "input_shape": (1, 3, 4, 112, 112),
        },
        "optical_flow": {
            "input_shape": (1, 3, 128, 128),
        },
    }
    model_name = model_fn.__name__
    module_name = model_fn.__module__.split(".")[-2]
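Note: the new test_raft is double-gated, by @needs_cuda and by @run_if_test_with_prototype. Judging from the run_on_env_var import above, the prototype gate is an environment-variable switch; a minimal sketch of such a gate (the helper name is real, but this body and the PYTORCH_TEST_WITH_PROTOTYPE variable are assumptions, not the actual common_utils implementation):

import os

import pytest

# Hypothetical reconstruction: skip prototype tests unless the env var is set,
# e.g. run with `PYTORCH_TEST_WITH_PROTOTYPE=1 pytest test/test_prototype_models.py -k raft`.
run_if_test_with_prototype = pytest.mark.skipif(
    os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
    reason="prototype tests only run when PYTORCH_TEST_WITH_PROTOTYPE=1",
)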
4 changes: 2 additions & 2 deletions torchvision/models/optical_flow/raft.py
@@ -585,7 +585,7 @@ def raft_large(*, pretrained=False, progress=True, **kwargs):
    """

    if pretrained:
        raise ValueError("Pretrained weights aren't available yet")
        raise ValueError("No checkpoint is available for raft_large")

    return _raft(
        # Feature encoder
@@ -631,7 +631,7 @@ def raft_small(*, pretrained=False, progress=True, **kwargs):
    """

    if pretrained:
        raise ValueError("Pretrained weights aren't available yet")
        raise ValueError("No checkpoint is available for raft_small")

    return _raft(
        # Feature encoder
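The only user-visible change in this file is the error text. A quick sanity check of the new message (a sketch; it assumes raft_large is importable from the stable torchvision.models.optical_flow namespace, as the file path suggests):

import pytest

from torchvision.models.optical_flow import raft_large

# pretrained=True still has no checkpoint to load; the error now names the builder.
with pytest.raises(ValueError, match="No checkpoint is available for raft_large"):
    raft_large(pretrained=True)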
1 change: 1 addition & 0 deletions torchvision/prototype/models/__init__.py
@@ -12,6 +12,7 @@
from .vgg import *
from .vision_transformer import *
from . import detection
from . import optical_flow
from . import quantization
from . import segmentation
from . import video
1 change: 1 addition & 0 deletions torchvision/prototype/models/optical_flow/__init__.py
@@ -0,0 +1 @@
from .raft import RAFT, raft_large, raft_small, Raft_Large_Weights, Raft_Small_Weights
166 changes: 166 additions & 0 deletions torchvision/prototype/models/optical_flow/raft.py
@@ -0,0 +1,166 @@
from typing import Optional

from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.models.optical_flow import RAFT
from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock

# from torchvision.prototype.transforms import RaftEval

from .._api import WeightsEnum

# from .._api import Weights
from .._utils import handle_legacy_interface


__all__ = (
    "RAFT",
    "raft_large",
    "raft_small",
    "Raft_Large_Weights",
    "Raft_Small_Weights",
)


class Raft_Large_Weights(WeightsEnum):
    pass
    # C_T_V1 = Weights(
    #     # Chairs + Things
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # C_T_SKHT_V1 = Weights(
    #     # Chairs + Things + Sintel fine-tuning, i.e.:
    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
    #     # Corresponds to C+T+S+K+H in the paper, with fine-tuning on Sintel
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # C_T_SKHT_K_V1 = Weights(
    #     # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, i.e.:
    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
    #     # Same as C_T_SKHT with extra fine-tuning on Kitti
    #     # Corresponds to C+T+S+K+H in the paper, with fine-tuning on Sintel and then on Kitti
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # default = C_T_V1


class Raft_Small_Weights(WeightsEnum):
    pass
    # C_T_V1 = Weights(
    #     url="",  # TODO
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )
    # default = C_T_V1


@handle_legacy_interface(weights=("pretrained", None))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
    """RAFT model from
    `RAFT: Recurrent All-Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

    Args:
        weights (Raft_Large_Weights, optional): TODO not implemented yet
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
            to override any default.

    Returns:
        nn.Module: The model.
    """

    weights = Raft_Large_Weights.verify(weights)

    return _raft(
        # Feature encoder
        feature_encoder_layers=(64, 64, 96, 128, 256),
        feature_encoder_block=ResidualBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(64, 64, 96, 128, 256),
        context_encoder_block=ResidualBlock,
        context_encoder_norm_layer=BatchNorm2d,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=4,
        # Motion encoder
        motion_encoder_corr_layers=(256, 192),
        motion_encoder_flow_layers=(128, 64),
        motion_encoder_out_channels=128,
        # Recurrent block
        recurrent_block_hidden_state_size=128,
        recurrent_block_kernel_size=((1, 5), (5, 1)),
        recurrent_block_padding=((0, 2), (2, 0)),
        # Flow head
        flow_head_hidden_size=256,
        # Mask predictor
        use_mask_predictor=True,
        **kwargs,
    )


@handle_legacy_interface(weights=("pretrained", None))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
    """RAFT "small" model from
    `RAFT: Recurrent All-Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

    Args:
        weights (Raft_Small_Weights, optional): TODO not implemented yet
        progress (bool): If True, displays a progress bar of the download to stderr
        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
            to override any default.

    Returns:
        nn.Module: The model.

    """

    weights = Raft_Small_Weights.verify(weights)

    return _raft(
        # Feature encoder
        feature_encoder_layers=(32, 32, 64, 96, 128),
        feature_encoder_block=BottleneckBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(32, 32, 64, 96, 160),
        context_encoder_block=BottleneckBlock,
        context_encoder_norm_layer=None,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=3,
        # Motion encoder
        motion_encoder_corr_layers=(96,),
        motion_encoder_flow_layers=(64, 32),
        motion_encoder_out_channels=82,
        # Recurrent block
        recurrent_block_hidden_state_size=96,
        recurrent_block_kernel_size=(3,),
        recurrent_block_padding=(1,),
        # Flow head
        flow_head_hidden_size=128,
        # Mask predictor
        use_mask_predictor=False,
        **kwargs,
    )
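Since both weight enums are still empty, the new builders are only usable with weights=None for now. A smoke-test sketch of the prototype entry points, assuming the forward contract of the stable torchvision.models.optical_flow.RAFT (two image batches in, one flow estimate per recurrent update out):

import torch

from torchvision.prototype.models import optical_flow

model = optical_flow.raft_small(weights=None)
model.eval()

# Two RGB frames; 128x128 matches the input_shape used by the tests above.
img1 = torch.rand(1, 3, 128, 128)
img2 = torch.rand(1, 3, 128, 128)

with torch.no_grad():
    flow_predictions = model(img1, img2)  # list with one (1, 2, 128, 128) flow per update
print(len(flow_predictions), flow_predictions[-1].shape)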
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -3,4 +3,4 @@

from ._geometry import Resize, RandomResize, HorizontalFlip, Crop, CenterCrop, RandomCrop
from ._misc import Identity, Normalize
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
37 changes: 36 additions & 1 deletion torchvision/prototype/transforms/_presets.py
@@ -6,7 +6,7 @@
from ...transforms import functional as F, InterpolationMode


__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval"]
__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval", "RaftEval"]


class CocoEval(nn.Module):
@@ -97,3 +97,38 @@ def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor,
            target = F.pil_to_tensor(target)
            target = target.squeeze(0).to(torch.int64)
        return img, target


class RaftEval(nn.Module):
    def forward(
        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:

        img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask)

        img1 = F.convert_image_dtype(img1, torch.float32)
        img2 = F.convert_image_dtype(img2, torch.float32)

        # map [0, 1] into [-1, 1]
        img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        img2 = F.normalize(img2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        img1 = img1.contiguous()
        img2 = img2.contiguous()

        return img1, img2, flow, valid_flow_mask

    def _pil_or_numpy_to_tensor(
        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        if not isinstance(img1, Tensor):
[Inline review thread on the line above]

Member Author: I find it a bit confusing that we type the input as Tensor yet handle the case where it's not a tensor. I saw that on the other presets, so I did the same. I assume this is only temporary, to check these presets on the current transforms (which return PIL images), and that we will remove the conversions eventually?

Contributor: Good point. It's also because the user is supposed to use these transforms during inference. At that point, you don't know whether they chose to read the image with PIL or with TorchVision's io, so here we support both. This is also done because the reference scripts for other tasks only support PIL. By the way, now that you've added a prototype section for your model, you should add support for it in your reference scripts on the other PR.

            img1 = F.pil_to_tensor(img1)
        if not isinstance(img2, Tensor):
            img2 = F.pil_to_tensor(img2)

        if flow is not None and not isinstance(flow, Tensor):
            flow = torch.from_numpy(flow)
        if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor):
            valid_flow_mask = torch.from_numpy(valid_flow_mask)

        return img1, img2, flow, valid_flow_mask
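To make the PIL-or-tensor discussion above concrete, a hedged usage sketch of the new preset with tensor inputs (the PIL and numpy branches produce the same result after conversion):

import torch

from torchvision.prototype.transforms import RaftEval

preset = RaftEval()

# uint8 frames as torchvision.io would return them; flow and the valid mask are optional.
img1 = torch.randint(0, 256, (3, 128, 128), dtype=torch.uint8)
img2 = torch.randint(0, 256, (3, 128, 128), dtype=torch.uint8)

img1, img2, flow, valid_flow_mask = preset(img1, img2, None, None)
print(img1.dtype, img1.min().item(), img1.max().item())  # float32, mapped into [-1, 1]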