Skip to content

Commit ccc0029

Browse files
committed
Address comments
1 parent 02a2640 commit ccc0029

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

references/optical_flow/transforms.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,13 @@ def forward(self, img1, img2, flow, valid_flow_mask):
101101

102102
class RandomErasing(T.RandomErasing):
103103
# This only erases img2, and with an extra max_erase param
104+
# This max_erase is needed because in the RAFT training ref does:
105+
# 0 erasing with .5 proba
106+
# 1 erase with .25 proba
107+
# 2 erase with .25 proba
108+
# and there's no accurate way to achieve this otherwise.
104109
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
105-
super().__init__()
110+
super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
106111
self.max_erase = max_erase
107112
assert self.max_erase > 0
108113

@@ -171,12 +176,12 @@ def forward(self, img1, img2, flow, valid_flow_mask):
171176
# It shouldn't matter much
172177
min_scale = max((self.crop_size[0] + 8) / h, (self.crop_size[1] + 8) / w)
173178

174-
scale = 2 ** torch.FloatTensor(1).uniform_(self.min_scale, self.max_scale).item()
179+
scale = 2 ** torch.empty(1, dtype=torch.float32).uniform_(self.min_scale, self.max_scale).item()
175180
scale_x = scale
176181
scale_y = scale
177182
if torch.rand(1) < self.stretch_prob:
178-
scale_x *= 2 ** torch.FloatTensor(1).uniform_(-self.max_stretch, self.max_stretch).item()
179-
scale_y *= 2 ** torch.FloatTensor(1).uniform_(-self.max_stretch, self.max_stretch).item()
183+
scale_x *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
184+
scale_y *= 2 ** torch.empty(1, dtype=torch.float32).uniform_(-self.max_stretch, self.max_stretch).item()
180185

181186
scale_x = max(scale_x, min_scale)
182187
scale_y = max(scale_y, min_scale)
@@ -245,8 +250,9 @@ def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0):
245250
return flow_new, valid_new
246251

247252

248-
class Compose:
253+
class Compose(torch.nn.Module):
249254
def __init__(self, transforms):
255+
super().__init__()
250256
self.transforms = transforms
251257

252258
def forward(self, img1, img2, flow, valid_flow_mask):

0 commit comments

Comments
 (0)