Add transforms and presets for optical flow models #5026
@@ -0,0 +1,64 @@
import torch
import transforms as T


class OpticalFlowPresetEval(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.transforms = T.Compose(
            [
                T.PILToTensor(),
                T.ConvertImageDtype(torch.float32),
                T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
                T.ValidateModelInput(),
            ]
        )

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)


class OpticalFlowPresetTrain(torch.nn.Module):
    def __init__(
        self,
        # RandomResizeAndCrop params
        crop_size,
        min_scale=-0.2,
        max_scale=0.5,
        stretch_prob=0.8,
        # AsymmetricColorJitter params
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.5 / 3.14,
        asymmetric_jitter_prob=0.2,
        # Random[H,V]Flip params
        do_flip=True,
    ):
        super().__init__()

        transforms = [
            T.PILToTensor(),
            T.AsymmetricColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
            ),
            T.RandomResizeAndCrop(
                crop_size=crop_size, min_scale=min_scale, max_scale=max_scale, stretch_prob=stretch_prob
            ),
        ]

        if do_flip:
            transforms += [T.RandomHorizontalFlip(p=0.5), T.RandomVerticalFlip(p=0.1)]

        transforms += [
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
            T.RandomErasing(max_erase=2),
            T.MakeValidFlowMask(),
            T.ValidateModelInput(),
        ]
        self.transforms = T.Compose(transforms)

    def forward(self, img1, img2, flow, valid):
        return self.transforms(img1, img2, flow, valid)
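
As a rough usage sketch (not part of the diff), the training preset above could be applied to a single (img1, img2, flow) triple as follows. The dummy PIL images and numpy flow stand in for what a dataset such as Sintel would return, `valid` is passed as `None` so that `MakeValidFlowMask` builds the mask, and the `from presets import ...` path is an assumption about where this file lives:

```python
# Hedged sketch: dummy data and the "presets" module path are assumptions, not from the PR.
import numpy as np
from PIL import Image

from presets import OpticalFlowPresetTrain  # assumed filename of the first file in this diff

H, W = 400, 720
img1 = Image.fromarray(np.random.randint(0, 256, (H, W, 3), dtype=np.uint8))
img2 = Image.fromarray(np.random.randint(0, 256, (H, W, 3), dtype=np.uint8))
flow = np.random.randn(2, H, W).astype(np.float32)

preset = OpticalFlowPresetTrain(crop_size=(368, 496))
img1, img2, flow, valid_flow_mask = preset(img1, img2, flow, None)

print(img1.shape, flow.shape, valid_flow_mask.shape)
# torch.Size([3, 368, 496]) torch.Size([2, 368, 496]) torch.Size([368, 496])
```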
@@ -0,0 +1,261 @@
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F


class ValidateModelInput(torch.nn.Module):
    # Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
    def forward(self, img1, img2, flow, valid_flow_mask):

        assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None)
        assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None)

        assert img1.shape == img2.shape
        h, w = img1.shape[-2:]
        if flow is not None:
            assert flow.shape == (2, h, w)
        if valid_flow_mask is not None:
            assert valid_flow_mask.shape == (h, w)
            assert valid_flow_mask.dtype == torch.bool

        return img1, img2, flow, valid_flow_mask


class MakeValidFlowMask(torch.nn.Module):
    # This transform generates a valid_flow_mask if it doesn't exist.
    # The flow is considered valid if ||flow||_inf < threshold
    # This is a noop for Kitti and HD1K which already come with a built-in flow mask.
    def __init__(self, threshold=1000):
        super().__init__()
        self.threshold = threshold

    def forward(self, img1, img2, flow, valid_flow_mask):
        if flow is not None and valid_flow_mask is None:
            valid_flow_mask = (flow.abs() < self.threshold).all(axis=0)
        return img1, img2, flow, valid_flow_mask
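
A tiny illustration (not from the diff) of the validity criterion above, using the class as defined here; the images are passed as `None` since this transform never touches them:

```python
# Any pixel where either flow component reaches the threshold is marked invalid.
import torch

flow = torch.zeros(2, 2, 2)
flow[0, 0, 0] = 1e6  # simulate an "infinite" flow value at pixel (0, 0)

mask = MakeValidFlowMask(threshold=1000)(None, None, flow, None)[3]
print(mask)
# tensor([[False,  True],
#         [ True,  True]])
```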


class ConvertImageDtype(torch.nn.Module):
    def __init__(self, dtype):
        super().__init__()
        self.dtype = dtype

    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.convert_image_dtype(img1, dtype=self.dtype)
        img2 = F.convert_image_dtype(img2, dtype=self.dtype)

        img1 = img1.contiguous()
        img2 = img2.contiguous()

        return img1, img2, flow, valid_flow_mask


class Normalize(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.normalize(img1, mean=self.mean, std=self.std)
        img2 = F.normalize(img2, mean=self.mean, std=self.std)

        return img1, img2, flow, valid_flow_mask


class PILToTensor(torch.nn.Module):
    # Converts all inputs to tensors.
    # Technically the flow and the valid mask are numpy arrays, not PIL images, but we keep that naming
    # for consistency with the rest, e.g. the segmentation reference.
    def forward(self, img1, img2, flow, valid_flow_mask):
        img1 = F.pil_to_tensor(img1)
        img2 = F.pil_to_tensor(img2)
        if flow is not None:
            flow = torch.from_numpy(flow)
        if valid_flow_mask is not None:
            valid_flow_mask = torch.from_numpy(valid_flow_mask)

        return img1, img2, flow, valid_flow_mask


class AsymmetricColorJitter(T.ColorJitter):
    # p determines the proba of doing asymmetric vs symmetric color jittering
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, p=0.2):
        super().__init__(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
        self.p = p

    def forward(self, img1, img2, flow, valid_flow_mask):

        if torch.rand(1) < self.p:
            # asymmetric: different transform for img1 and img2
            img1 = super().forward(img1)
            img2 = super().forward(img2)
        else:
            # symmetric: same transform for img1 and img2

A review thread is attached to the symmetric branch below:

- "@NicolasHug So does the […] @pmeier Could you please check this strange transform to confirm it's supported by the new Transforms API?"
- "As it stands, this would not be supported. A transform always treats a sample as an atomic unit, and so multiple images in the same sample would be transformed with the same parameters."
- "OK, I'll clarify. Ultimately this is a special case of […]"
- "@NicolasHug Sounds good, just add comments. No need to use RandomApply here."
- "@pmeier No worries, this is why we give the option for someone to write custom transforms without the magic of the new API, for weird cases like this. Could you now confirm that this is indeed a workaround we can apply?"
- "I'm guessing […] I think this is one of the cases @datumbox mentioned where we need to circumvent the automatic dispatch a little. In case we want to transform both images separately, we could split the sample and perform the transformation once for the sample minus image 2 and once for image 2. The problem I see with this is that it can't be automated without assumptions about how the sample is structured. So we either need to use the same structure for every dataset (for example a flat dictionary with […])"
- "Each transform would receive the entire input (which IIRC is a dict) and operate on a subset of that dict. Are you suggesting that img1 and img2 would be concatenated?"

            batch = torch.stack([img1, img2])
            batch = super().forward(batch)
            img1, img2 = batch[0], batch[1]

        return img1, img2, flow, valid_flow_mask
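
To make the asymmetric/symmetric distinction above concrete, here is a rough check (not part of the diff) built on plain `T.ColorJitter`: stacking the pair applies a single set of jitter parameters to both images, while two separate calls re-sample the parameters:

```python
# Sketch only; the jitter strengths below are arbitrary, not the preset defaults.
import torch
import torchvision.transforms as T

torch.manual_seed(0)
jitter = T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)

img1 = torch.rand(3, 8, 8)
img2 = img1.clone()

# symmetric: one set of parameters for the whole stacked batch
batch = jitter(torch.stack([img1, img2]))
print(torch.equal(batch[0], batch[1]))  # True

# asymmetric: parameters are re-sampled on each call
print(torch.equal(jitter(img1), jitter(img2)))  # False (with overwhelming probability)
```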


class RandomErasing(T.RandomErasing):
    # This only erases img2, and with an extra max_erase param.
    # This max_erase is needed because the RAFT training ref does:
    #   0 erasing with .5 proba
    #   1 erase with .25 proba
    #   2 erase with .25 proba
    # and there's no accurate way to achieve this otherwise.
    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
        super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
        self.max_erase = max_erase
        assert self.max_erase > 0

    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) > self.p:
            return img1, img2, flow, valid_flow_mask

        for _ in range(torch.randint(self.max_erase, size=(1,)).item()):
            x, y, h, w, v = self.get_params(img2, scale=self.scale, ratio=self.ratio, value=[self.value])
            img2 = F.erase(img2, x, y, h, w, v, self.inplace)

        return img1, img2, flow, valid_flow_mask


class RandomHorizontalFlip(T.RandomHorizontalFlip):
    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) > self.p:
            return img1, img2, flow, valid_flow_mask

        img1 = F.hflip(img1)
        img2 = F.hflip(img2)
        flow = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]
        if valid_flow_mask is not None:
            valid_flow_mask = F.hflip(valid_flow_mask)
        return img1, img2, flow, valid_flow_mask


class RandomVerticalFlip(T.RandomVerticalFlip):
    def forward(self, img1, img2, flow, valid_flow_mask):
        if torch.rand(1) > self.p:
            return img1, img2, flow, valid_flow_mask

        img1 = F.vflip(img1)
        img2 = F.vflip(img2)
        flow = F.vflip(flow) * torch.tensor([1, -1])[:, None, None]
        if valid_flow_mask is not None:
            valid_flow_mask = F.vflip(valid_flow_mask)
        return img1, img2, flow, valid_flow_mask
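
A small sanity check (not in the diff) of the sign handling in the two flip transforms above, shown for the horizontal case: flipping mirrors horizontal displacements, so the x (first) flow channel must be negated while the y channel is left alone:

```python
# Mirrors the line `flow = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]` above.
import torch
import torchvision.transforms.functional as F

flow = torch.zeros(2, 1, 4)
flow[0, 0, 1] = 3.0  # pixel at column 1 moves 3 px to the right
flow[1, 0, 1] = 2.0  # and 2 px down

flipped = F.hflip(flow) * torch.tensor([-1, 1])[:, None, None]
print(flipped[:, 0, 2])  # tensor([-3., 2.]): column 1 maps to column 2, x-flow negated
```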


class RandomResizeAndCrop(torch.nn.Module):
    # This transform will resize the input with a given proba, and then crop it.
    # These are the reversed operations of the built-in RandomResizedCrop,
    # although the order of the operations doesn't matter too much: resizing a
    # crop would give the same result as cropping a resized image, up to
    # interpolation artifacts at the borders of the output.
    #
    # The reason we don't rely on RandomResizedCrop is because of a significant
    # difference in the parametrization of both transforms, in particular
    # because of the way the random parameters are sampled in both transforms,
    # which leads to fairly different results (and different epe). For more details see
    # https://github.com/pytorch/vision/pull/5026/files#r762932579
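    # For instance, with the defaults min_scale=-0.2 and max_scale=0.5, the scale sampled
    # below as 2 ** uniform(min_scale, max_scale) falls roughly in [2 ** -0.2, 2 ** 0.5] ≈ [0.87, 1.41].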
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, stretch_prob=0.8):
        super().__init__()
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.stretch_prob = stretch_prob
        self.resize_prob = 0.8
        self.max_stretch = 0.2

    def forward(self, img1, img2, flow, valid_flow_mask):
        # randomly sample scale
        h, w = img1.shape[-2:]
        # Note: in the original code, they use + 1 instead of + 8 for sparse datasets (e.g. Kitti)
        # It shouldn't matter much
        min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)

        scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
        scale_x = scale
        scale_y = scale
        if torch.rand(1) < self.stretch_prob:
            scale_x *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
            scale_y *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()

        scale_x = max(scale_x, min_scale)
        scale_y = max(scale_y, min_scale)

        new_h, new_w = round(h * scale_y), round(w * scale_x)

        if torch.rand(1).item() < self.resize_prob:
            # rescale the images
            img1 = F.resize(img1, size=(new_h, new_w))
            img2 = F.resize(img2, size=(new_h, new_w))
            if valid_flow_mask is None:
                flow = F.resize(flow, size=(new_h, new_w))
                flow = flow * torch.tensor([scale_x, scale_y])[:, None, None]
            else:
                flow, valid_flow_mask = self._resize_sparse_flow(
                    flow, valid_flow_mask, scale_x=scale_x, scale_y=scale_y
                )

        # Note: For sparse datasets (Kitti), the original code uses a "margin"
        # See e.g. https://github.com/princeton-vl/RAFT/blob/master/core/utils/augmentor.py#L220:L220
        # We don't, not sure it matters much
        y0 = torch.randint(0, img1.shape[1] - self.crop_size[0], size=(1,)).item()
        x0 = torch.randint(0, img1.shape[2] - self.crop_size[1], size=(1,)).item()

        img1 = F.crop(img1, y0, x0, self.crop_size[0], self.crop_size[1])
        img2 = F.crop(img2, y0, x0, self.crop_size[0], self.crop_size[1])
        flow = F.crop(flow, y0, x0, self.crop_size[0], self.crop_size[1])
        if valid_flow_mask is not None:
            valid_flow_mask = F.crop(valid_flow_mask, y0, x0, self.crop_size[0], self.crop_size[1])

        return img1, img2, flow, valid_flow_mask

    def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0):
        # This resizes both the flow and the valid_flow_mask (which is assumed to be reasonably sparse)
        # There are as many non-zero values in the original flow as in the resized flow (up to OOB)
        # So for example if scale_x = scale_y = 2, the sparsity of the output flow is multiplied by 4

        h, w = flow.shape[-2:]

        h_new = int(round(h * scale_y))
        w_new = int(round(w * scale_x))
        flow_new = torch.zeros(size=[2, h_new, w_new], dtype=flow.dtype)
        valid_new = torch.zeros(size=[h_new, w_new], dtype=valid_flow_mask.dtype)

        jj, ii = torch.meshgrid(torch.arange(w), torch.arange(h), indexing="xy")

        ii_valid, jj_valid = ii[valid_flow_mask], jj[valid_flow_mask]

        ii_valid_new = torch.round(ii_valid.to(float) * scale_y).to(torch.long)
        jj_valid_new = torch.round(jj_valid.to(float) * scale_x).to(torch.long)

        within_bounds_mask = (0 <= ii_valid_new) & (ii_valid_new < h_new) & (0 <= jj_valid_new) & (jj_valid_new < w_new)

        ii_valid = ii_valid[within_bounds_mask]
        jj_valid = jj_valid[within_bounds_mask]
        ii_valid_new = ii_valid_new[within_bounds_mask]
        jj_valid_new = jj_valid_new[within_bounds_mask]

        valid_flow_new = flow[:, ii_valid, jj_valid]
        valid_flow_new[0] *= scale_x
        valid_flow_new[1] *= scale_y

        flow_new[:, ii_valid_new, jj_valid_new] = valid_flow_new
        valid_new[ii_valid_new, jj_valid_new] = 1

        return flow_new, valid_new
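
As an illustration (not part of the diff) of the sparse-resize behaviour described in `_resize_sparse_flow`: upscaling by 2 in both directions keeps the same number of valid flow vectors, relocates them on the larger grid, and rescales their values:

```python
# Calls the method above directly on a toy 4x4 flow with a single valid vector.
import torch

flow = torch.zeros(2, 4, 4)
valid = torch.zeros(4, 4, dtype=torch.bool)
flow[:, 1, 2] = torch.tensor([5.0, -3.0])  # one valid vector at (row=1, col=2)
valid[1, 2] = True

t = RandomResizeAndCrop(crop_size=(2, 2))
flow_new, valid_new = t._resize_sparse_flow(flow, valid, scale_x=2.0, scale_y=2.0)

print(valid_new.sum().item())  # 1: still a single valid location on the 8x8 grid
print(flow_new[:, 2, 4])       # tensor([10., -6.]): relocated to (2, 4) and rescaled
```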


class Compose(torch.nn.Module):
    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms

    def forward(self, img1, img2, flow, valid_flow_mask):
        for t in self.transforms:
            img1, img2, flow, valid_flow_mask = t(img1, img2, flow, valid_flow_mask)
        return img1, img2, flow, valid_flow_mask