Add raft builders and presets in prototypes #5043
Changes from all commits
@@ -0,0 +1 @@
from .raft import RAFT, raft_large, raft_small, Raft_Large_Weights, Raft_Small_Weights
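These re-exports make the builders and weight enums importable from the prototype namespace. As a minimal sketch of the intended entry point (the exact import path is an assumption inferred from the relative `.._api` imports in `raft.py` below):

from torchvision.prototype.models.optical_flow import raft_large, raft_small

model = raft_small()  # no weights are published yet, so this builds a randomly initialized model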
@@ -0,0 +1,166 @@
from typing import Optional

from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.models.optical_flow import RAFT
from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock

# from torchvision.prototype.transforms import RaftEval

from .._api import WeightsEnum

# from .._api import Weights
from .._utils import handle_legacy_interface


__all__ = (
    "RAFT",
    "raft_large",
    "raft_small",
    "Raft_Large_Weights",
    "Raft_Small_Weights",
)


class Raft_Large_Weights(WeightsEnum):
    pass
    # C_T_V1 = Weights(
    #     # Chairs + Things
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # C_T_SKHT_V1 = Weights(
    #     # Chairs + Things + Sintel fine-tuning, i.e.:
    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
    #     # Corresponds to C+T+S+K+H in the paper, with fine-tuning on Sintel
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # C_T_SKHT_K_V1 = Weights(
    #     # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, i.e.:
    #     # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean) + Kitti
    #     # Same as C_T_SKHT with extra fine-tuning on Kitti
    #     # Corresponds to C+T+S+K+H in the paper, with fine-tuning on Sintel and then on Kitti
    #     url="",
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )

    # default = C_T_V1


class Raft_Small_Weights(WeightsEnum):
    pass
    # C_T_V1 = Weights(
    #     url="",  # TODO
    #     transforms=RaftEval,
    #     meta={
    #         "recipe": "",
    #         "epe": -1234,
    #     },
    # )
    # default = C_T_V1

@handle_legacy_interface(weights=("pretrained", None))
def raft_large(*, weights: Optional[Raft_Large_Weights] = None, progress=True, **kwargs):
    """RAFT model from
    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

    Args:
        weights (Raft_Large_Weights, optional): TODO not implemented yet
        progress (bool): If True, displays a progress bar of the download to stderr.
        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
            to override any default.

    Returns:
        nn.Module: The model.
    """

    weights = Raft_Large_Weights.verify(weights)

    return _raft(
        # Feature encoder
        feature_encoder_layers=(64, 64, 96, 128, 256),
        feature_encoder_block=ResidualBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(64, 64, 96, 128, 256),
        context_encoder_block=ResidualBlock,
        context_encoder_norm_layer=BatchNorm2d,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=4,
        # Motion encoder
        motion_encoder_corr_layers=(256, 192),
        motion_encoder_flow_layers=(128, 64),
        motion_encoder_out_channels=128,
        # Recurrent block
        recurrent_block_hidden_state_size=128,
        recurrent_block_kernel_size=((1, 5), (5, 1)),
        recurrent_block_padding=((0, 2), (2, 0)),
        # Flow head
        flow_head_hidden_size=256,
        # Mask predictor
        use_mask_predictor=True,
        **kwargs,
    )


@handle_legacy_interface(weights=("pretrained", None))
def raft_small(*, weights: Optional[Raft_Small_Weights] = None, progress=True, **kwargs):
    """RAFT "small" model from
    `RAFT: Recurrent All Pairs Field Transforms for Optical Flow <https://arxiv.org/abs/2003.12039>`_.

    Args:
        weights (Raft_Small_Weights, optional): TODO not implemented yet
        progress (bool): If True, displays a progress bar of the download to stderr.
        kwargs (dict): Parameters that will be passed to the :class:`~torchvision.models.optical_flow.RAFT` class
            to override any default.

    Returns:
        nn.Module: The model.
    """

    weights = Raft_Small_Weights.verify(weights)

    return _raft(
        # Feature encoder
        feature_encoder_layers=(32, 32, 64, 96, 128),
        feature_encoder_block=BottleneckBlock,
        feature_encoder_norm_layer=InstanceNorm2d,
        # Context encoder
        context_encoder_layers=(32, 32, 64, 96, 160),
        context_encoder_block=BottleneckBlock,
        context_encoder_norm_layer=None,
        # Correlation block
        corr_block_num_levels=4,
        corr_block_radius=3,
        # Motion encoder
        motion_encoder_corr_layers=(96,),
        motion_encoder_flow_layers=(64, 32),
        motion_encoder_out_channels=82,
        # Recurrent block
        recurrent_block_hidden_state_size=96,
        recurrent_block_kernel_size=(3,),
        recurrent_block_padding=(1,),
        # Flow head
        flow_head_hidden_size=128,
        # Mask predictor
        use_mask_predictor=False,
        **kwargs,
    )
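For a quick smoke test of the two builders, the following hedged sketch assumes, as in torchvision's existing RAFT implementation, that the forward pass takes two batches of RGB frames whose height and width are divisible by 8 and returns the list of intermediate flow predictions, with the most refined estimate last:

import torch

# Assumption: import path follows this PR's prototype layout.
from torchvision.prototype.models.optical_flow import raft_large, raft_small

# Two batches of frames; RAFT requires spatial dimensions divisible by 8.
img1 = torch.rand(1, 3, 368, 496)
img2 = torch.rand(1, 3, 368, 496)

for builder in (raft_large, raft_small):
    model = builder().eval()  # randomly initialized; no weights published yet
    with torch.no_grad():
        # Assumption: the model returns intermediate flow predictions,
        # the last one being the finest, with shape (N, 2, H, W).
        flow_predictions = model(img1, img2)
    print(builder.__name__, flow_predictions[-1].shape)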
@@ -6,7 +6,7 @@
from ...transforms import functional as F, InterpolationMode


-__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval"]
+__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval", "RaftEval"]


class CocoEval(nn.Module):

@@ -97,3 +97,38 @@ def forward(self, img: Tensor, target: Optional[Tensor] = None) -> Tuple[Tensor,
            target = F.pil_to_tensor(target)
            target = target.squeeze(0).to(torch.int64)
        return img, target


class RaftEval(nn.Module):
    def forward(
        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:

        img1, img2, flow, valid_flow_mask = self._pil_or_numpy_to_tensor(img1, img2, flow, valid_flow_mask)

        img1 = F.convert_image_dtype(img1, torch.float32)
        img2 = F.convert_image_dtype(img2, torch.float32)

        # map [0, 1] into [-1, 1]
        img1 = F.normalize(img1, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        img2 = F.normalize(img2, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        img1 = img1.contiguous()
        img2 = img2.contiguous()

        return img1, img2, flow, valid_flow_mask

    def _pil_or_numpy_to_tensor(
        self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
        if not isinstance(img1, Tensor):
            img1 = F.pil_to_tensor(img1)
        if not isinstance(img2, Tensor):
            img2 = F.pil_to_tensor(img2)

        if flow is not None and not isinstance(flow, Tensor):
            flow = torch.from_numpy(flow)
        if valid_flow_mask is not None and not isinstance(valid_flow_mask, Tensor):
            valid_flow_mask = torch.from_numpy(valid_flow_mask)

        return img1, img2, flow, valid_flow_mask

Review comment (on `_pil_or_numpy_to_tensor`): I find it a bit confusing that we type the input as `Tensor`.

Reply: Good point. It's also because the user is supposed to use these transforms during inference. At that point, you don't know if they chose to read the image with PIL or with TV's io, so we support both here. This is also done because the reference scripts for other tasks only support PIL. BTW, now that you added a prototype section for your model, you should add support for it in your reference scripts in the other PR.
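Putting the preset and a builder together, here is a hedged end-to-end sketch. It relies only on what the diff above shows: `RaftEval` takes (img1, img2, flow, valid_flow_mask), accepts PIL images or tensors, converts to float32, and rescales to [-1, 1]; the commented-out import in `raft.py` suggests the preset is exposed as `torchvision.prototype.transforms.RaftEval`, while the models path is again an assumption:

import torch

from torchvision.prototype.models.optical_flow import raft_large
from torchvision.prototype.transforms import RaftEval

preset = RaftEval()
model = raft_large().eval()

# uint8 frames, e.g. as returned by torchvision.io.read_image; PIL images
# would also work since the preset converts them with pil_to_tensor.
img1 = torch.randint(0, 256, (3, 368, 496), dtype=torch.uint8)
img2 = torch.randint(0, 256, (3, 368, 496), dtype=torch.uint8)

# At inference time there is no ground-truth flow or validity mask.
img1, img2, _, _ = preset(img1, img2, None, None)

with torch.no_grad():
    # Assumption: the last prediction is the finest flow estimate, (1, 2, H, W).
    flow = model(img1.unsqueeze(0), img2.unsqueeze(0))[-1]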