@@ -101,8 +101,13 @@ def forward(self, img1, img2, flow, valid_flow_mask):
101
101
102
102
class RandomErasing (T .RandomErasing ):
103
103
# This only erases img2, and with an extra max_erase param
104
+ # This max_erase is needed because in the RAFT training ref does:
105
+ # 0 erasing with .5 proba
106
+ # 1 erase with .25 proba
107
+ # 2 erase with .25 proba
108
+ # and there's no accurate way to achieve this otherwise.
104
109
def __init__ (self , p = 0.5 , scale = (0.02 , 0.33 ), ratio = (0.3 , 3.3 ), value = 0 , inplace = False , max_erase = 1 ):
105
- super ().__init__ ()
110
+ super ().__init__ (p = p , scale = scale , ratio = ratio , value = value , inplace = inplace )
106
111
self .max_erase = max_erase
107
112
assert self .max_erase > 0
108
113
@@ -171,12 +176,12 @@ def forward(self, img1, img2, flow, valid_flow_mask):
171
176
# It shouldn't matter much
172
177
min_scale = max ((self .crop_size [0 ] + 8 ) / h , (self .crop_size [1 ] + 8 ) / w )
173
178
174
- scale = 2 ** torch .FloatTensor ( 1 ).uniform_ (self .min_scale , self .max_scale ).item ()
179
+ scale = 2 ** torch .empty ( 1 , dtype = torch . float32 ).uniform_ (self .min_scale , self .max_scale ).item ()
175
180
scale_x = scale
176
181
scale_y = scale
177
182
if torch .rand (1 ) < self .stretch_prob :
178
- scale_x *= 2 ** torch .FloatTensor ( 1 ).uniform_ (- self .max_stretch , self .max_stretch ).item ()
179
- scale_y *= 2 ** torch .FloatTensor ( 1 ).uniform_ (- self .max_stretch , self .max_stretch ).item ()
183
+ scale_x *= 2 ** torch .empty ( 1 , dtype = torch . float32 ).uniform_ (- self .max_stretch , self .max_stretch ).item ()
184
+ scale_y *= 2 ** torch .empty ( 1 , dtype = torch . float32 ).uniform_ (- self .max_stretch , self .max_stretch ).item ()
180
185
181
186
scale_x = max (scale_x , min_scale )
182
187
scale_y = max (scale_y , min_scale )
@@ -245,8 +250,9 @@ def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0):
245
250
return flow_new , valid_new
246
251
247
252
248
- class Compose :
253
+ class Compose ( torch . nn . Module ) :
249
254
def __init__ (self , transforms ):
255
+ super ().__init__ ()
250
256
self .transforms = transforms
251
257
252
258
def forward (self , img1 , img2 , flow , valid_flow_mask ):
0 commit comments