diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index d6d42344fcb..f3f34fd7fce 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -570,6 +570,13 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill # Apply same grid to a batch of images grid = grid.expand(squashed_batch_size, -1, -1, -1) + if fill is not None and not isinstance(fill, (tuple, list)): + fill = [float(fill)] + + # filling with zeros is the default behavior and thus we can skip the extra fill handling + if fill is not None and all(f == 0 for f in fill): + fill = None + # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice if fill is not None: mask = torch.ones( @@ -583,8 +590,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill if fill is not None: float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3) mask = mask.expand_as(float_img) - fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)] # type: ignore[arg-type] - fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1) + fill_img = torch.tensor(fill, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1) if mode == "nearest": bool_mask = mask < 0.5 float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]