From 25b36671d91445982365b86a867daf765c2ad116 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 22 Jun 2022 09:53:11 +0000
Subject: [PATCH 1/9] Added base tests for rotate_image_tensor

---
 test/test_prototype_transforms_functional.py     | 16 ++++++++++++++++
 .../prototype/transforms/functional/_geometry.py |  2 +-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index be3932a8b7f..6ec912a4770 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -284,6 +284,22 @@ def affine_segmentation_mask():
     )
 
 
+@register_kernel_info_from_sample_inputs_fn
+def rotate_image_tensor():
+    for image, angle, expand, center, fill in itertools.product(
+        make_images(extra_dims=((), (4,))),
+        [-87, 15, 90],  # angle
+        [True, False],  # expand
+        [None, [12, 23]],  # center
+        [None, [128]],  # fill
+    ):
+        if center is not None and expand:
+            # Skip warning: The provided center argument is ignored if expand is True
+            continue
+
+        yield SampleInput(image, angle=angle, expand=expand, center=center, fill=fill)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def rotate_bounding_box():
     for bounding_box, angle, expand, center in itertools.product(
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index ac0e8e0eb13..d71706dbb65 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -318,8 +318,8 @@ def rotate_image_tensor(
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    fill: Optional[List[float]] = None,
     center: Optional[List[float]] = None,
+    fill: Optional[List[float]] = None,
 ) -> torch.Tensor:
     center_f = [0.0, 0.0]
     if center is not None:

From 6b3483ddc90625bd780b4e9eca1c88592a9c63c3 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 22 Jun 2022 11:07:45 +0000
Subject: [PATCH 2/9] Updated resize_image_tensor API and tests and fixed a bug with max_size

---
 test/test_prototype_transforms_functional.py | 17 +++++++++++------
 .../transforms/functional/_geometry.py       |  7 ++++---
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 6ec912a4770..5550c99158a 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -201,19 +201,24 @@ def horizontal_flip_bounding_box():
 
 @register_kernel_info_from_sample_inputs_fn
 def resize_image_tensor():
-    for image, interpolation in itertools.product(
+    for image, interpolation, max_size, antialias in itertools.product(
         make_images(),
-        [
-            F.InterpolationMode.BILINEAR,
-            F.InterpolationMode.NEAREST,
-        ],
+        [F.InterpolationMode.BILINEAR, F.InterpolationMode.NEAREST],  # interpolation
+        [None, 34],  # max_size
+        [False, True],  # antialias
    ):
+
+        if antialias and interpolation == F.InterpolationMode.NEAREST:
+            continue
+
         height, width = image.shape[-2:]
         for size in [
             (height, width),
             (int(height * 0.75), int(width * 1.25)),
         ]:
-            yield SampleInput(image, size=size, interpolation=interpolation)
+            if max_size is not None:
+                size = [size[0]]
+            yield SampleInput(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
 
 
 @register_kernel_info_from_sample_inputs_fn
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index d71706dbb65..9b3e370dde8 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -42,16 +42,17 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
 ) -> torch.Tensor:
-    new_height, new_width = size
     num_channels, old_height, old_width = get_dimensions_image_tensor(image)
     batch_shape = image.shape[:-3]
-    return _FT.resize(
+    output = _FT.resize(
         image.reshape((-1, num_channels, old_height, old_width)),
         size=size,
         interpolation=interpolation.value,
         max_size=max_size,
         antialias=antialias,
-    ).reshape(batch_shape + (num_channels, new_height, new_width))
+    )
+    num_channels, new_height, new_width = get_dimensions_image_tensor(output)
+    return output.reshape(batch_shape + (num_channels, new_height, new_width))
 
 
 def resize_image_pil(
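Note: patch 1 moves fill after center in rotate_image_tensor's signature, and patch 9 at the end of this series swaps them back, so any caller passing these arguments positionally would silently change behavior between revisions. A minimal sketch of the safe calling pattern, assuming the prototype namespace exposes the kernel the same way the tests above access it:

    import torch
    from torchvision.prototype.transforms import functional as F

    image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)

    # Keyword arguments are immune to the fill/center reordering; note that
    # center is ignored (with a warning) whenever expand=True.
    rotated = F.rotate_image_tensor(image, angle=15.0, expand=True, fill=[128.0])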
From ea7c513ff69dedface1d5e4c1708ae5b6ebe9fdf Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 22 Jun 2022 15:36:11 +0000
Subject: [PATCH 3/9] Refactored and modified private api for resize functional op

---
 test/test_transforms_tensor.py              | 20 +++------
 torchvision/transforms/functional.py        | 46 +++++++++++++++++++-
 torchvision/transforms/functional_pil.py    | 34 +--------------
 torchvision/transforms/functional_tensor.py | 47 +--------------------
 4 files changed, 52 insertions(+), 95 deletions(-)

diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index ba2321ec455..f0cd3ba0021 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -394,9 +394,7 @@ def test_resize_int(self, size):
     @pytest.mark.parametrize(
         "size",
         [
-            [
-                32,
-            ],
+            [32],
             [32, 32],
             (32, 32),
             [34, 35],
@@ -412,7 +410,7 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device):
         # This is a trivial cast to float of uint8 data to test all cases
         tensor = tensor.to(dt)
         if max_size is not None and len(size) != 1:
-            pytest.xfail("with max_size, size must be a sequence with 2 elements")
+            pytest.skip("Size should be an int or a sequence of length 1 if max_size is specified")
         transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size)
         s_transform = torch.jit.script(transform)
         _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
     def test_resize_save(self, tmpdir):
-        transform = T.Resize(
-            size=[
-                32,
-            ]
-        )
+        transform = T.Resize(size=[32])
         s_transform = torch.jit.script(transform)
         s_transform.save(os.path.join(tmpdir, "t_resize.pt"))
 
@@ -435,12 +429,8 @@
         "size",
         [
             (32,),
-            [
-                44,
-            ],
-            [
-                32,
-            ],
+            [44],
+            [32],
             [32, 32],
             (32, 32),
             [44, 55],
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index c40ae1eb92b..609c64ad4ff 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -360,6 +360,31 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
 
 
+def _compute_output_size(
+    image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
+) -> Tuple[int, int]:
+    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
+        h, w = image_size
+        short, long = (w, h) if w <= h else (h, w)
+        requested_new_short = size if isinstance(size, int) else size[0]
+
+        new_short, new_long = requested_new_short, int(requested_new_short * long / short)
+
+        if max_size is not None:
+            if max_size <= requested_new_short:
+                raise ValueError(
+                    f"max_size = {max_size} must be strictly greater than the requested "
+                    f"size for the smaller edge size = {size}"
+                )
+            if new_long > max_size:
+                new_short, new_long = int(max_size * new_short / new_long), max_size
+
+        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+    else:  # specified both h and w
+        new_w, new_h = size[1], size[0]
+    return new_h, new_w
+
+
 def resize(
     img: Tensor,
     size: List[int],
@@ -423,13 +448,30 @@
     if not isinstance(interpolation, InterpolationMode):
         raise TypeError("Argument interpolation should be a InterpolationMode")
 
+    if isinstance(size, (list, tuple)):
+        if len(size) not in [1, 2]:
+            raise ValueError(
+                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
+            )
+        if max_size is not None and len(size) != 1:
+            raise ValueError(
+                "max_size should only be passed if size specifies the length of the smaller edge, "
+                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
+            )
+
+    _, image_height, image_width = get_dimensions(img)
+    output_size = _compute_output_size((image_height, image_width), size, max_size)
+
+    if (image_height, image_width) == output_size:
+        return img
+
     if not isinstance(img, torch.Tensor):
         if antialias is not None and not antialias:
             warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
         pil_interpolation = pil_modes_mapping[interpolation]
-        return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
+        return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)
 
-    return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias)
+    return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 0203ee4495b..3c1a911a5d4 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -242,44 +242,14 @@ def resize(
     img: Image.Image,
     size: Union[Sequence[int], int],
     interpolation: int = _pil_constants.BILINEAR,
-    max_size: Optional[int] = None,
 ) -> Image.Image:
 
     if not _is_pil_image(img):
         raise TypeError(f"img should be PIL Image. Got {type(img)}")
-    if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
+    if not (isinstance(size, Sequence) and len(size) == 2):
         raise TypeError(f"Got inappropriate size arg: {size}")
 
-    if isinstance(size, Sequence) and len(size) == 1:
-        size = size[0]
-    if isinstance(size, int):
-        w, h = img.size
-
-        short, long = (w, h) if w <= h else (h, w)
-        new_short, new_long = size, int(size * long / short)
-
-        if max_size is not None:
-            if max_size <= size:
-                raise ValueError(
-                    f"max_size = {max_size} must be strictly greater than the requested "
-                    f"size for the smaller edge size = {size}"
-                )
-            if new_long > max_size:
-                new_short, new_long = int(max_size * new_short / new_long), max_size
-
-        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
-
-        if (w, h) == (new_w, new_h):
-            return img
-        else:
-            return img.resize((new_w, new_h), interpolation)
-    else:
-        if max_size is not None:
-            raise ValueError(
-                "max_size should only be passed if size specifies the length of the smaller edge, "
-                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
-            )
-        return img.resize(size[::-1], interpolation)
+    return img.resize(size[::-1], interpolation)
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 1899caebfc3..acc8d3ae3e1 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -430,70 +430,25 @@ def resize(
     img: Tensor,
     size: List[int],
     interpolation: str = "bilinear",
-    max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
 ) -> Tensor:
     _assert_image_tensor(img)
 
-    if not isinstance(size, (int, tuple, list)):
-        raise TypeError("Got inappropriate size arg")
-    if not isinstance(interpolation, str):
-        raise TypeError("Got inappropriate interpolation arg")
-
-    if interpolation not in ["nearest", "bilinear", "bicubic"]:
-        raise ValueError("This interpolation mode is unsupported with Tensor input")
-
     if isinstance(size, tuple):
         size = list(size)
 
-    if isinstance(size, list):
-        if len(size) not in [1, 2]:
-            raise ValueError(
-                f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
-            )
-        if max_size is not None and len(size) != 1:
-            raise ValueError(
-                "max_size should only be passed if size specifies the length of the smaller edge, "
-                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
-            )
-
     if antialias is None:
         antialias = False
 
     if antialias and interpolation not in ["bilinear", "bicubic"]:
         raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
 
-    _, h, w = get_dimensions(img)
-
-    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
-        short, long = (w, h) if w <= h else (h, w)
-        requested_new_short = size if isinstance(size, int) else size[0]
-
-        new_short, new_long = requested_new_short, int(requested_new_short * long / short)
-
-        if max_size is not None:
-            if max_size <= requested_new_short:
-                raise ValueError(
-                    f"max_size = {max_size} must be strictly greater than the requested "
-                    f"size for the smaller edge size = {size}"
-                )
-            if new_long > max_size:
-                new_short, new_long = int(max_size * new_short / new_long), max_size
-
-        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
-
-        if (w, h) == (new_w, new_h):
-            return img
-
-    else:  # specified both h and w
-        new_w, new_h = size[1], size[0]
-
     img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
 
     # Define align_corners to avoid warnings
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None
 
-    img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners, antialias=antialias)
+    img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)
 
     if interpolation == "bicubic" and out_dtype == torch.uint8:
         img = img.clamp(min=0, max=255)
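The smaller-edge/max_size arithmetic that patch 3 centralizes in the private _compute_output_size helper can be sanity-checked in isolation. A standalone sketch of just that arithmetic (eliding the int-size branch, the both-edges branch, and the max_size validation present in the real helper):

    def compute_output_size(image_size, size, max_size=None):
        # image_size is (h, w); size is a 1-element list giving the new smaller edge
        h, w = image_size
        short, long = (w, h) if w <= h else (h, w)
        new_short, new_long = size[0], int(size[0] * long / short)
        if max_size is not None and new_long > max_size:
            # shrink further so the longer edge does not exceed max_size
            new_short, new_long = int(max_size * new_short / new_long), max_size
        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
        return new_h, new_w

    assert compute_output_size((400, 600), [300]) == (300, 450)
    assert compute_output_size((400, 600), [300], max_size=400) == (266, 400)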
From aade78f8a7bf36dbe70aaca7afcd2abf546d3ccb Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 22 Jun 2022 16:07:31 +0000
Subject: [PATCH 4/9] Fixed failures

---
 torchvision/prototype/transforms/functional/_geometry.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index ac0e8e0eb13..f1d51fded82 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -42,6 +42,8 @@ def resize_image_tensor(
     max_size: Optional[int] = None,
     antialias: Optional[bool] = None,
 ) -> torch.Tensor:
+    # TODO: use _compute_output_size to enable max_size option
+    max_size  # unused right now
     new_height, new_width = size
     num_channels, old_height, old_width = get_dimensions_image_tensor(image)
     batch_shape = image.shape[:-3]
@@ -49,7 +51,6 @@ def resize_image_tensor(
         image.reshape((-1, num_channels, old_height, old_width)),
         size=size,
         interpolation=interpolation.value,
-        max_size=max_size,
         antialias=antialias,
     ).reshape(batch_shape + (num_channels, new_height, new_width))
 
@@ -60,7 +61,9 @@ def resize_image_pil(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
 ) -> PIL.Image.Image:
-    return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation], max_size=max_size)
+    # TODO: use _compute_output_size to enable max_size option
+    max_size  # unused right now
+    return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation])
 
 
 def resize_segmentation_mask(
From a812a3bcdcca6ca8d7af79330220f1344ae89aa9 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 22 Jun 2022 20:57:14 +0000
Subject: [PATCH 5/9] More updates

---
 torchvision/transforms/functional.py     | 12 ++++++------
 torchvision/transforms/functional_pil.py |  4 ++--
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 609c64ad4ff..77feadc51f1 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -2,7 +2,7 @@
 import numbers
 import warnings
 from enum import Enum
-from typing import List, Tuple, Any, Optional
+from typing import List, Tuple, Any, Optional, Union
 
 import numpy as np
 import torch
@@ -360,10 +360,8 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
 
 
-def _compute_output_size(
-    image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
-) -> Tuple[int, int]:
-    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
+def _compute_output_size(image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None) -> List[int]:
+    if len(size) == 1:  # specified size only for the smallest edge
         h, w = image_size
         short, long = (w, h) if w <= h else (h, w)
         requested_new_short = size if isinstance(size, int) else size[0]
@@ -382,7 +380,7 @@ def _compute_output_size(
         new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
     else:  # specified both h and w
         new_w, new_h = size[1], size[0]
-    return new_h, new_w
+    return [new_h, new_w]
 
 
 def resize(
@@ -460,6 +458,8 @@
         )
 
     _, image_height, image_width = get_dimensions(img)
+    if isinstance(size, int):
+        size = [size]
     output_size = _compute_output_size((image_height, image_width), size, max_size)
 
     if (image_height, image_width) == output_size:
diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py
index 3c1a911a5d4..7ebd9f71588 100644
--- a/torchvision/transforms/functional_pil.py
+++ b/torchvision/transforms/functional_pil.py
@@ -240,13 +240,13 @@ def crop(
 @torch.jit.unused
 def resize(
     img: Image.Image,
-    size: Union[Sequence[int], int],
+    size: Union[List[int], int],
     interpolation: int = _pil_constants.BILINEAR,
 ) -> Image.Image:
 
     if not _is_pil_image(img):
         raise TypeError(f"img should be PIL Image. Got {type(img)}")
-    if not (isinstance(size, Sequence) and len(size) == 2):
+    if not (isinstance(size, list) and len(size) == 2):
         raise TypeError(f"Got inappropriate size arg: {size}")
 
     return img.resize(size[::-1], interpolation)
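Patch 5 narrows _compute_output_size to List[int] in and out, with the public resize() normalizing an int size to [size] before the call. Under torch.jit.script a function must have a single well-defined return type, so always returning a List[int] (rather than sometimes a tuple) keeps the whole resize path scriptable. A short usage sketch matching the tests from patch 3:

    import torch
    import torchvision.transforms as T

    # max_size requires a 1-element size; plain ints are only accepted in eager
    # mode, where resize() normalizes them to [size] before reaching the helper.
    transform = T.Resize(size=[32], max_size=48)
    scripted = torch.jit.script(transform)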
From 6661d8d948a0535cc22e30fe84a81ffac8d48661 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 22 Jun 2022 21:01:36 +0000
Subject: [PATCH 6/9] Updated proto functional op: resize_image_*

---
 .../prototype/transforms/functional/_geometry.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 2aaed3e4a2e..36015b5c25d 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -6,7 +6,12 @@
 import torch
 from torchvision.prototype import features
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
-from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix, InterpolationMode, _compute_output_size
+from torchvision.transforms.functional import (
+    pil_modes_mapping,
+    _get_inverse_affine_matrix,
+    InterpolationMode,
+    _compute_output_size,
+)
 
 from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil
@@ -43,7 +48,8 @@ def resize_image_tensor(
     antialias: Optional[bool] = None,
 ) -> torch.Tensor:
     num_channels, old_height, old_width = get_dimensions_image_tensor(image)
-    new_height, new_width = _compute_output_size((old_height, old_width), size=size, max_size=max_size)
+    size = _compute_output_size((old_height, old_width), size=size, max_size=max_size)
+    new_height, new_width = size
     batch_shape = image.shape[:-3]
     return _FT.resize(
         image.reshape((-1, num_channels, old_height, old_width)),
 
 def resize_image_pil(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     max_size: Optional[int] = None,
 ) -> PIL.Image.Image:
+    if isinstance(size, int):
+        size = [size, size]
+    # Explicitly cast size to list otherwise mypy issue: incompatible type "Sequence[int]"; expected "List[int]"
+    size: List[int] = list(size)
     size = _compute_output_size(img.size[::-1], size=size, max_size=max_size)
     return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation])
From f0c896ff1391dcac098539db79f14e2c50549d7a Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 22 Jun 2022 21:27:44 +0000
Subject: [PATCH 7/9] Added max_size arg to resize_bounding_box and updated basic tests

---
 test/test_prototype_transforms_functional.py | 23 ++++++++++++++++++-
 .../transforms/functional/_geometry.py       |  6 +++--
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 5550c99158a..30d9b833ec8 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -223,15 +223,36 @@ def resize_bounding_box():
-    for bounding_box in make_bounding_boxes():
+    for bounding_box, max_size in itertools.product(
+        make_bounding_boxes(),
+        [None, 34],  # max_size
+    ):
         height, width = bounding_box.image_size
         for size in [
             (height, width),
             (int(height * 0.75), int(width * 1.25)),
         ]:
+            if max_size is not None:
+                size = [size[0]]
-            yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)
+            yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size, max_size=max_size)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def resize_segmentation_mask():
+    for mask, max_size in itertools.product(
+        make_segmentation_masks(),
+        [None, 34],  # max_size
+    ):
+        height, width = mask.shape[-2:]
+        for size in [
+            (height, width),
+            (int(height * 0.75), int(width * 1.25)),
+        ]:
+            if max_size is not None:
+                size = [size[0]]
+            yield SampleInput(mask, size=size, max_size=max_size)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def affine_image_tensor():
     for image, angle, translate, scale, shear in itertools.product(
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 36015b5c25d..19085d2a974 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -79,9 +79,11 @@ def resize_segmentation_mask(
     return resize_image_tensor(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
 
 
-# TODO: handle max_size
-def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor:
+def resize_bounding_box(
+    bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int], max_size: Optional[int] = None
+) -> torch.Tensor:
     old_height, old_width = image_size
+    size = _compute_output_size(image_size, size=size, max_size=max_size)
     new_height, new_width = size
     ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
     return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)

From 6a5e5ab19de2a020a5797f2e3fd5606b78a1ae3d Mon Sep 17 00:00:00 2001
From: vfdev
Date: Thu, 23 Jun 2022 12:16:51 +0200
Subject: [PATCH 8/9] Update functional.py

---
 torchvision/transforms/functional.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 691eff84426..2a4a7f1b6dd 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -2,7 +2,7 @@
 import numbers
 import warnings
 from enum import Enum
-from typing import List, Tuple, Any, Optional, Union
+from typing import List, Tuple, Any, Optional
 
 import numpy as np
 import torch
From b2ada459b27b8e9875c0305cf970cf3fbb76096b Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Thu, 23 Jun 2022 10:36:08 +0000
Subject: [PATCH 9/9] Reverted fill/center order for rotate

Other nits
---
 .../prototype/transforms/functional/_geometry.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 19085d2a974..95e094ad798 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -48,12 +48,11 @@ def resize_image_tensor(
     num_channels, old_height, old_width = get_dimensions_image_tensor(image)
-    size = _compute_output_size((old_height, old_width), size=size, max_size=max_size)
-    new_height, new_width = size
+    new_height, new_width = _compute_output_size((old_height, old_width), size=size, max_size=max_size)
     batch_shape = image.shape[:-3]
     return _FT.resize(
         image.reshape((-1, num_channels, old_height, old_width)),
-        size=size,
+        size=[new_height, new_width],
         interpolation=interpolation.value,
         antialias=antialias,
     ).reshape(batch_shape + (num_channels, new_height, new_width))
 
 def resize_bounding_box(
     bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int], max_size: Optional[int] = None
 ) -> torch.Tensor:
     old_height, old_width = image_size
-    size = _compute_output_size(image_size, size=size, max_size=max_size)
-    new_height, new_width = size
+    new_height, new_width = _compute_output_size(image_size, size=size, max_size=max_size)
     ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
     return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape)
@@ -330,8 +328,8 @@ def rotate_image_tensor(
     angle: float,
     interpolation: InterpolationMode = InterpolationMode.NEAREST,
     expand: bool = False,
-    center: Optional[List[float]] = None,
     fill: Optional[List[float]] = None,
+    center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     center_f = [0.0, 0.0]
     if center is not None:
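With the full series applied, resize_image_tensor, resize_segmentation_mask and resize_bounding_box all route through _compute_output_size, so max_size behaves uniformly across the prototype kernels. A usage sketch, assuming the prototype functional namespace exposes the kernels the way the tests do (the expected shapes follow the smaller-edge arithmetic shown earlier):

    import torch
    from torchvision.prototype.transforms import functional as F

    image = torch.rand(3, 400, 600)
    bbox = torch.tensor([[30.0, 40.0, 300.0, 200.0]])  # xyxy, relative to the 400x600 image

    out_image = F.resize_image_tensor(image, size=[300], max_size=400)
    assert out_image.shape[-2:] == (266, 400)  # 300x450 capped by max_size to 266x400

    # the box coordinates are rescaled with the same computed output size
    out_bbox = F.resize_bounding_box(bbox, size=[300], image_size=(400, 600), max_size=400)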