diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 3dc8813b9f8..73563bccd18 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -394,9 +394,7 @@ def test_resize_int(self, size):
     @pytest.mark.parametrize(
         "size",
         [
-            [
-                32,
-            ],
+            [32],
             [32, 32],
             (32, 32),
             [34, 35],
@@ -412,7 +410,7 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device):
             # This is a trivial cast to float of uint8 data to test all cases
             tensor = tensor.to(dt)
         if max_size is not None and len(size) != 1:
-            pytest.xfail("with max_size, size must be a sequence with 2 elements")
+            pytest.skip("Size should be an int or a sequence of length 1 if max_size is specified")
 
         transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size)
         s_transform = torch.jit.script(transform)
@@ -420,11 +418,7 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device):
         _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_resize_save(self, tmpdir):
-        transform = T.Resize(
-            size=[
-                32,
-            ]
-        )
+        transform = T.Resize(size=[32])
         s_transform = torch.jit.script(transform)
         s_transform.save(os.path.join(tmpdir, "t_resize.pt"))
 
@@ -435,12 +429,8 @@ def test_resize_save(self, tmpdir):
         "size",
         [
             (32,),
-            [
-                44,
-            ],
-            [
-                32,
-            ],
+            [44],
+            [32],
             [32, 32],
             (32, 32),
             [44, 55],
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index ac0e8e0eb13..f1d51fded82 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -42,6 +42,8 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
 ) -> torch.Tensor:
+    # TODO: use _compute_output_size to enable max_size option
+    max_size  # unused right now
     new_height, new_width = size
     num_channels, old_height, old_width = get_dimensions_image_tensor(image)
     batch_shape = image.shape[:-3]
@@ -49,7 +51,6 @@ def resize_image_tensor(
         image.reshape((-1, num_channels, old_height, old_width)),
         size=size,
         interpolation=interpolation.value,
-        max_size=max_size,
         antialias=antialias,
     ).reshape(batch_shape + (num_channels, new_height, new_width))
 
@@ -60,7 +61,9 @@ def resize_image_pil(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
 ) -> PIL.Image.Image:
-    return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation], max_size=max_size)
+    # TODO: use _compute_output_size to enable max_size option
+    max_size  # unused right now
+    return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation])
 
 
 def resize_segmentation_mask(
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index fb4c7e6677d..2a4a7f1b6dd 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -360,6 +360,29 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
 
 
+def _compute_output_size(image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None) -> List[int]:
+    if len(size) == 1:  # specified size only for the smallest edge
+        h, w = image_size
+        short, long = (w, h) if w <= h else (h, w)
+        requested_new_short = size if isinstance(size, int) else size[0]
+
+        new_short, new_long = requested_new_short, int(requested_new_short * long / short)
+
+        if max_size is not None:
+            if max_size <= requested_new_short:
+                raise ValueError(
+                    f"max_size = {max_size} must be strictly greater than the requested "
+                    f"size for the smaller edge size = {size}"
+                )
+            if new_long > max_size:
+                new_short, new_long = int(max_size * new_short / new_long), max_size
+
+        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+    else:  # specified both h and w
+        new_w, new_h = size[1], size[0]
+    return [new_h, new_w]
+
+
 def resize(
     img: Tensor,
     size: List[int],
@@ -423,13 +446,32 @@ def resize(
     if not isinstance(interpolation, InterpolationMode):
         raise TypeError("Argument interpolation should be a InterpolationMode")
 
+    if isinstance(size, (list, tuple)):
+        if len(size) not in [1, 2]:
+            raise ValueError(
+                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
+            )
+        if max_size is not None and len(size) != 1:
+            raise ValueError(
+                "max_size should only be passed if size specifies the length of the smaller edge, "
+                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
+            )
+
+    _, image_height, image_width = get_dimensions(img)
+    if isinstance(size, int):
+        size = [size]
+    output_size = _compute_output_size((image_height, image_width), size, max_size)
+
+    if [image_height, image_width] == output_size:
+        return img
+
     if not isinstance(img, torch.Tensor):
         if antialias is not None and not antialias:
             warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
         pil_interpolation = pil_modes_mapping[interpolation]
-        return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
+        return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)
 
-    return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)
+    return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
 
 
 def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 0203ee4495b..93bdeb8f308 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -1,5 +1,5 @@
 import numbers
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -240,46 +240,16 @@ def crop(
 @torch.jit.unused
 def resize(
     img: Image.Image,
-    size: Union[Sequence[int], int],
+    size: Union[List[int], int],
     interpolation: int = _pil_constants.BILINEAR,
-    max_size: Optional[int] = None,
 ) -> Image.Image:
 
     if not _is_pil_image(img):
         raise TypeError(f"img should be PIL Image. Got {type(img)}")
-    if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
+    if not (isinstance(size, list) and len(size) == 2):
         raise TypeError(f"Got inappropriate size arg: {size}")
 
-    if isinstance(size, Sequence) and len(size) == 1:
-        size = size[0]
-    if isinstance(size, int):
-        w, h = img.size
-
-        short, long = (w, h) if w <= h else (h, w)
-        new_short, new_long = size, int(size * long / short)
-
-        if max_size is not None:
-            if max_size <= size:
-                raise ValueError(
-                    f"max_size = {max_size} must be strictly greater than the requested "
-                    f"size for the smaller edge size = {size}"
-                )
-            if new_long > max_size:
-                new_short, new_long = int(max_size * new_short / new_long), max_size
-
-        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
-
-        if (w, h) == (new_w, new_h):
-            return img
-        else:
-            return img.resize((new_w, new_h), interpolation)
-    else:
-        if max_size is not None:
-            raise ValueError(
-                "max_size should only be passed if size specifies the length of the smaller edge, "
-                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
-            )
-        return img.resize(size[::-1], interpolation)
+    return img.resize(size[::-1], interpolation)
 
 
 @torch.jit.unused
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 1899caebfc3..acc8d3ae3e1 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -430,70 +430,25 @@ def resize(
     img: Tensor,
     size: List[int],
     interpolation: str = "bilinear",
-    max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
 ) -> Tensor:
     _assert_image_tensor(img)
 
-    if not isinstance(size, (int, tuple, list)):
-        raise TypeError("Got inappropriate size arg")
-    if not isinstance(interpolation, str):
-        raise TypeError("Got inappropriate interpolation arg")
-
-    if interpolation not in ["nearest", "bilinear", "bicubic"]:
-        raise ValueError("This interpolation mode is unsupported with Tensor input")
-
     if isinstance(size, tuple):
         size = list(size)
 
-    if isinstance(size, list):
-        if len(size) not in [1, 2]:
-            raise ValueError(
-                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
-            )
-        if max_size is not None and len(size) != 1:
-            raise ValueError(
-                "max_size should only be passed if size specifies the length of the smaller edge, "
-                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
-            )
-
     if antialias is None:
         antialias = False
 
     if antialias and interpolation not in ["bilinear", "bicubic"]:
         raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
 
-    _, h, w = get_dimensions(img)
-
-    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
-        short, long = (w, h) if w <= h else (h, w)
-        requested_new_short = size if isinstance(size, int) else size[0]
-
-        new_short, new_long = requested_new_short, int(requested_new_short * long / short)
-
-        if max_size is not None:
-            if max_size <= requested_new_short:
-                raise ValueError(
-                    f"max_size = {max_size} must be strictly greater than the requested "
-                    f"size for the smaller edge size = {size}"
-                )
-            if new_long > max_size:
-                new_short, new_long = int(max_size * new_short / new_long), max_size
-
-        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
-
-        if (w, h) == (new_w, new_h):
-            return img
-
-    else:  # specified both h and w
-        new_w, new_h = size[1], size[0]
-
     img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
 
     # Define align_corners to avoid warnings
    align_corners = False if interpolation in ["bilinear", "bicubic"] else None
 
-    img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners, antialias=antialias)
+    img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)
 
     if interpolation == "bicubic" and out_dtype == torch.uint8:
         img = img.clamp(min=0, max=255)
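Note (illustrative, not part of the patch): the sketch below mirrors the smaller-edge / max_size computation that the new _compute_output_size helper centralizes in torchvision/transforms/functional.py. The function name compute_output_size and the example image and target sizes are arbitrary stand-ins, chosen only to show the math.

# Standalone sketch of the output-size computation consolidated by this diff.
# compute_output_size is a hypothetical stand-in for the private helper above.
from typing import List, Optional, Tuple


def compute_output_size(image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None) -> List[int]:
    if len(size) == 1:  # size given only for the smaller edge
        h, w = image_size
        short, long = (w, h) if w <= h else (h, w)
        new_short, new_long = size[0], int(size[0] * long / short)
        if max_size is not None:
            if max_size <= new_short:
                raise ValueError("max_size must be strictly greater than the requested smaller edge size")
            if new_long > max_size:
                # Cap the longer edge and rescale the shorter one to keep the aspect ratio.
                new_short, new_long = int(max_size * new_short / new_long), max_size
        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
        return [new_h, new_w]
    return [size[0], size[1]]  # both height and width were given


# A 500x400 (HxW) image with the smaller edge resized to 100:
print(compute_output_size((500, 400), [100]))                 # [125, 100]
# Same request, but the longer edge capped at 110 via max_size:
print(compute_output_size((500, 400), [100], max_size=110))   # [110, 88]

Calling the public API after this patch should give matching shapes, e.g. F.resize(torch.rand(3, 500, 400), size=[100], max_size=110) is expected to return a tensor of shape (3, 110, 88).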