Skip to content

Commit a33a0aa

Browse files
committed
add grid transform zero fill shortcut
1 parent f69eee6 commit a33a0aa

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,13 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
570570
# Apply same grid to a batch of images
571571
grid = grid.expand(squashed_batch_size, -1, -1, -1)
572572

573+
if fill is not None and not isinstance(fill, (tuple, list)):
574+
fill = [float(fill)]
575+
576+
# filling with zeros is the default behavior and thus we can skip the extra fill handling
577+
if fill is not None and all(f == 0 for f in fill):
578+
fill = None
579+
573580
# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
574581
if fill is not None:
575582
mask = torch.ones(
@@ -583,8 +590,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
583590
if fill is not None:
584591
float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
585592
mask = mask.expand_as(float_img)
586-
fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] # type: ignore[arg-type]
587-
fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
593+
fill_img = torch.tensor(fill, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
588594
if mode == "nearest":
589595
bool_mask = mask < 0.5
590596
float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]

0 commit comments

Comments
 (0)