@@ -570,6 +570,13 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
570
570
# Apply same grid to a batch of images
571
571
grid = grid .expand (squashed_batch_size , - 1 , - 1 , - 1 )
572
572
573
+ if fill is not None and not isinstance (fill , (tuple , list )):
574
+ fill = [float (fill )]
575
+
576
+ # filling with zeros is the default behavior and thus we can skip the extra fill handling
577
+ if fill is not None and all (f == 0 for f in fill ):
578
+ fill = None
579
+
573
580
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
574
581
if fill is not None :
575
582
mask = torch .ones (
@@ -583,8 +590,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
583
590
if fill is not None :
584
591
float_img , mask = torch .tensor_split (float_img , indices = (- 1 ,), dim = - 3 )
585
592
mask = mask .expand_as (float_img )
586
- fill_list = fill if isinstance (fill , (tuple , list )) else [float (fill )] # type: ignore[arg-type]
587
- fill_img = torch .tensor (fill_list , dtype = float_img .dtype , device = float_img .device ).view (1 , - 1 , 1 , 1 )
593
+ fill_img = torch .tensor (fill , dtype = float_img .dtype , device = float_img .device ).view (1 , - 1 , 1 , 1 )
588
594
if mode == "nearest" :
589
595
bool_mask = mask < 0.5
590
596
float_img [bool_mask ] = fill_img .expand_as (float_img )[bool_mask ]
0 commit comments