Fill arg and _apply_grid_transform improvements #6517

Open
@vfdev-5

Description

A few years ago we introduced handling of non-constant fill values in _apply_grid_transform using a mask approach:

# Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
if fill is not None:
    dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
    img = torch.cat((img, dummy), dim=1)
img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
# Fill with required color
if fill is not None:
    mask = img[:, -1:, :, :]  # N * 1 * H * W
    img = img[:, :-1, :, :]  # N * C * H * W
    mask = mask.expand_as(img)
    len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1
    fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img)
    if mode == "nearest":
        mask = mask < 0.5
        img[mask] = fill_img[mask]
    else:  # 'bilinear'
        img = img * mask + (1.0 - mask) * fill_img
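To make the role of the dummy channel concrete, here is a minimal 1-D sketch in plain Python. It is not the torchvision code: `sample_linear_zeros` and `blend_at` are hypothetical helpers that stand in for grid_sample with padding_mode="zeros" on a single row. The point it illustrates is that the sampled all-ones channel holds, for each output position, the share of interpolation weight that landed inside the image.

```python
import math


def sample_linear_zeros(row, x):
    """Linearly interpolate `row` at fractional index `x`; out-of-range taps count as 0."""
    left = math.floor(x)
    frac = x - left

    def tap(i):
        return row[i] if 0 <= i < len(row) else 0.0

    return (1.0 - frac) * tap(left) + frac * tap(left + 1)


def blend_at(x, row, fill):
    """Return (sampled value, sampled mask, blended result) at position `x`."""
    ones = [1.0] * len(row)  # the appended "dummy" mask channel
    value = sample_linear_zeros(row, x)
    mask = sample_linear_zeros(ones, x)
    # Same formula as the 'bilinear' branch in the snippet above
    blended = value * mask + (1.0 - mask) * fill
    return value, mask, blended


row = [10.0, 20.0, 30.0, 40.0]  # one image row, single channel
for x in (1.5, 3.5, 4.5):
    print(x, blend_at(x, row, fill=255.0))
# 1.5 -> (25.0, 1.0, 25.0)    fully inside, fill is not used
# 3.5 -> (20.0, 0.5, 137.5)   half of the weight fell outside the row
# 4.5 -> (0.0, 0.0, 255.0)    fully outside, pure fill
```

Note that at x = 3.5 the sampled value 20.0 is already the in-bounds pixel (40.0) scaled by the mask (0.5), and the blend formula scales it by the mask once more.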

There are a few minor problems with this approach:

  1. If we pass fill = [0.0, ], we would expect a result similar to fill=None. This is not exactly true for the bilinear interpolation mode, where the sampled image is linearly blended with the fill image:
    else:  # 'bilinear'
        img = img * mask + (1.0 - mask) * fill_img

grid_sample has already attenuated the border pixels through zero padding, so multiplying by mask attenuates them a second time. Most probably, we would like to skip fill_img creation for all fill values with sum(fill) == 0, since grid_sample already pads with zeros.

- if fill is not None:
+ if fill is not None and sum(fill) > 0:
  2. The linear interpolation between fill_img and img could be replaced by applying a mask directly:
         mask = mask < 0.9999
         img[mask] = fill_img[mask] 

That would better match PIL Image behaviour.

else:  # 'bilinear'
    img = img * mask + (1.0 - mask) * fill_img
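Both points can be reproduced with a small PyTorch sketch. The helper `sample_with_fill` and its `hard_mask` flag are hypothetical names introduced here for illustration; the body follows the snippet quoted in this issue, restricted to bilinear mode, and uses torch.where as an out-of-place equivalent of `img[mask] = fill_img[mask]`. The test image is all ones, shifted by a quarter of its width so that the last column samples exactly half inside and half outside.

```python
import torch
from torch.nn.functional import affine_grid, grid_sample


def sample_with_fill(img, grid, fill, hard_mask=False):
    """Bilinear grid_sample with the mask-channel fill (hypothetical helper).

    hard_mask=False: current linear blend; hard_mask=True: proposed thresholded mask.
    """
    if fill is None:
        return grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    # Append the all-ones channel and sample it together with the image
    dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
    out = grid_sample(torch.cat((img, dummy), dim=1), grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    mask = out[:, -1:, :, :]
    out = out[:, :-1, :, :]
    mask = mask.expand_as(out)
    fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, -1, 1, 1).expand_as(out)
    if hard_mask:
        # Proposal 2: any pixel that is not (almost) fully inside becomes fill
        return torch.where(mask < 0.9999, fill_img, out)
    # Current behaviour: linear blend between sampled image and fill
    return out * mask + (1.0 - mask) * fill_img


img = torch.ones(1, 1, 2, 4)
# Shift sampling positions by 0.25 in normalized x, i.e. half a pixel for W=4
theta = torch.tensor([[[1.0, 0.0, 0.25], [0.0, 1.0, 0.0]]])
grid = affine_grid(theta, size=list(img.shape), align_corners=False)

out_none = sample_with_fill(img, grid, fill=None)
out_zero = sample_with_fill(img, grid, fill=[0.0])
out_hard = sample_with_fill(img, grid, fill=[0.0], hard_mask=True)

print(out_none[0, 0, 0])  # last column is about 0.5
print(out_zero[0, 0, 0])  # last column is about 0.25, darker than fill=None
print(out_hard[0, 0, 0])  # last column is exactly the fill value 0.0
```

With fill=None the half-covered column keeps 0.5 of the pixel value, while fill=[0.0] drops it to 0.25 because the mask is applied on top of the zero padding. The thresholded variant gives the border pixel the fill value outright; whether that matches PIL in all cases is the claim of this issue and is not verified by the sketch.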

image

cc @datumbox
