Skip to content

Refactored and modified private api for resize functional op #6191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 5 additions & 15 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,7 @@ def test_resize_int(self, size):
@pytest.mark.parametrize(
"size",
[
[
32,
],
[32],
[32, 32],
(32, 32),
[34, 35],
Expand All @@ -412,19 +410,15 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device):
# This is a trivial cast to float of uint8 data to test all cases
tensor = tensor.to(dt)
if max_size is not None and len(size) != 1:
pytest.xfail("with max_size, size must be a sequence with 2 elements")
pytest.skip("Size should be an int or a sequence of length 1 if max_size is specified")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Why skip instead of xfail?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good question :)

According to https://docs.pytest.org/en/stable/how-to/skipping.html
We skip in case windows tests on non-windows platform and xfail if we expect a test to fail for some reason.

In our case, the configuration "max_size is not None and len(size) != 1" can not be tested as we explicitly raise and error. Proper solution to that is to catch the error and check the message text with pytest.raises.

I felt like skipping is more appropriate than xfail. But it is a matter of taste if you think it is better to revert to xfail I can do that and will only fix the message which is incorrect

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was expecting a pytest.raises where if I'm honest but no strong opinions.


transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size)
s_transform = torch.jit.script(transform)
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

def test_resize_save(self, tmpdir):
transform = T.Resize(
size=[
32,
]
)
transform = T.Resize(size=[32])
s_transform = torch.jit.script(transform)
s_transform.save(os.path.join(tmpdir, "t_resize.pt"))

Expand All @@ -435,12 +429,8 @@ def test_resize_save(self, tmpdir):
"size",
[
(32,),
[
44,
],
[
32,
],
[44],
[32],
[32, 32],
(32, 32),
[44, 55],
Expand Down
7 changes: 5 additions & 2 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ def resize_image_tensor(
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> torch.Tensor:
# TODO: use _compute_output_size to enable max_size option
max_size # ununsed right now
new_height, new_width = size
num_channels, old_height, old_width = get_dimensions_image_tensor(image)
batch_shape = image.shape[:-3]
return _FT.resize(
image.reshape((-1, num_channels, old_height, old_width)),
size=size,
interpolation=interpolation.value,
max_size=max_size,
antialias=antialias,
).reshape(batch_shape + (num_channels, new_height, new_width))

Expand All @@ -60,7 +61,9 @@ def resize_image_pil(
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
) -> PIL.Image.Image:
return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation], max_size=max_size)
# TODO: use _compute_output_size to enable max_size option
max_size # ununsed right now
return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation])


def resize_segmentation_mask(
Expand Down
46 changes: 44 additions & 2 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,29 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)


def _compute_output_size(image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None) -> List[int]:
if len(size) == 1: # specified size only for the smallest edge
h, w = image_size
short, long = (w, h) if w <= h else (h, w)
requested_new_short = size if isinstance(size, int) else size[0]

new_short, new_long = requested_new_short, int(requested_new_short * long / short)

if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size

new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
else: # specified both h and w
new_w, new_h = size[1], size[0]
return [new_h, new_w]


def resize(
img: Tensor,
size: List[int],
Expand Down Expand Up @@ -423,13 +446,32 @@ def resize(
if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")

if isinstance(size, (list, tuple)):
if len(size) not in [1, 2]:
raise ValueError(
f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
)
if max_size is not None and len(size) != 1:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)

_, image_height, image_width = get_dimensions(img)
if isinstance(size, int):
size = [size]
output_size = _compute_output_size((image_height, image_width), size, max_size)

if (image_height, image_width) == output_size:
return img

if not isinstance(img, torch.Tensor):
if antialias is not None and not antialias:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)

return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)
return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)


def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
Expand Down
38 changes: 4 additions & 34 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numbers
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -240,46 +240,16 @@ def crop(
@torch.jit.unused
def resize(
img: Image.Image,
size: Union[Sequence[int], int],
size: Union[List[int], int],
interpolation: int = _pil_constants.BILINEAR,
max_size: Optional[int] = None,
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}")
if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
if not (isinstance(size, list) and len(size) == 2):
raise TypeError(f"Got inappropriate size arg: {size}")

if isinstance(size, Sequence) and len(size) == 1:
size = size[0]
if isinstance(size, int):
w, h = img.size

short, long = (w, h) if w <= h else (h, w)
new_short, new_long = size, int(size * long / short)

if max_size is not None:
if max_size <= size:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size

new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)

if (w, h) == (new_w, new_h):
return img
else:
return img.resize((new_w, new_h), interpolation)
else:
if max_size is not None:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
return img.resize(size[::-1], interpolation)
return img.resize(size[::-1], interpolation)


@torch.jit.unused
Expand Down
47 changes: 1 addition & 46 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,70 +430,25 @@ def resize(
img: Tensor,
size: List[int],
interpolation: str = "bilinear",
max_size: Optional[int] = None,
antialias: Optional[bool] = None,
) -> Tensor:
_assert_image_tensor(img)

if not isinstance(size, (int, tuple, list)):
raise TypeError("Got inappropriate size arg")
if not isinstance(interpolation, str):
raise TypeError("Got inappropriate interpolation arg")

if interpolation not in ["nearest", "bilinear", "bicubic"]:
raise ValueError("This interpolation mode is unsupported with Tensor input")

if isinstance(size, tuple):
size = list(size)

if isinstance(size, list):
if len(size) not in [1, 2]:
raise ValueError(
f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
)
if max_size is not None and len(size) != 1:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)

if antialias is None:
antialias = False

if antialias and interpolation not in ["bilinear", "bicubic"]:
raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")

_, h, w = get_dimensions(img)

if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge
short, long = (w, h) if w <= h else (h, w)
requested_new_short = size if isinstance(size, int) else size[0]

new_short, new_long = requested_new_short, int(requested_new_short * long / short)

if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size

new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)

if (w, h) == (new_w, new_h):
return img

else: # specified both h and w
new_w, new_h = size[1], size[0]

img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None

img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners, antialias=antialias)
img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)

if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255)
Expand Down