Commit 4c049ca

replace new_like with wrap_like (#6718)
* replace new_like with wrap_like
* fix videos
* revert casting in favor of ignoring mypy
1 parent 3118fb5 commit 4c049ca

18 files changed: 239 additions & 196 deletions
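
For orientation, a minimal before/after sketch of the renamed helper, assuming the prototype features namespace as of this commit; the call site below mirrors the updated test and is illustrative, not code taken from the diff:

import torch
from torchvision.prototype import features

label = features.Label(torch.tensor([0, 1, 0]), categories=["foo", "bar"])
output = label * 2  # plain tensor math: the result loses the Label subclass

# before this commit: re-construct a new feature "like" an existing one
# label_new = features.Label.new_like(label, output)

# after this commit: wrap the existing tensor as the same feature type, without copying
label_new = features.Label.wrap_like(label, output)
assert type(label_new) is features.Label
assert label_new.data_ptr() == output.data_ptr()  # same storage, no re-allocation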

test/test_prototype_features.py

Lines changed: 2 additions & 2 deletions
@@ -99,14 +99,14 @@ def test_inplace_op_no_wrapping():
     assert type(label) is features.Label


-def test_new_like():
+def test_wrap_like():
     tensor = torch.tensor([0, 1, 0], dtype=torch.int64)
     label = features.Label(tensor, categories=["foo", "bar"])

     # any operation besides .to() and .clone() will do here
     output = label * 2

-    label_new = features.Label.new_like(label, output)
+    label_new = features.Label.wrap_like(label, output)

     assert type(label_new) is features.Label
     assert label_new.data_ptr() == output.data_ptr()

test/test_prototype_transforms.py

Lines changed: 9 additions & 8 deletions
@@ -8,6 +8,7 @@
 import torch
 from common_utils import assert_equal, cpu_and_gpu
 from prototype_common_utils import (
+    DEFAULT_EXTRA_DIMS,
     make_bounding_box,
     make_bounding_boxes,
     make_detection_mask,
@@ -23,6 +24,8 @@
 from torchvision.prototype import features, transforms
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

+BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
+

 def make_vanilla_tensor_images(*args, **kwargs):
     for image in make_images(*args, **kwargs):
@@ -109,13 +112,11 @@ def test_common(self, transform, input):
             (
                 transform,
                 [
-                    dict(
-                        image=features.Image.new_like(image, image.unsqueeze(0), dtype=torch.float),
-                        one_hot_label=features.OneHotLabel.new_like(
-                            one_hot_label, one_hot_label.unsqueeze(0), dtype=torch.float
-                        ),
+                    dict(image=image, one_hot_label=one_hot_label)
+                    for image, one_hot_label in itertools.product(
+                        make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
+                        make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
                     )
-                    for image, one_hot_label in itertools.product(make_images(), make_one_hot_labels())
                 ],
             )
             for transform in [
@@ -300,7 +301,7 @@ def test_features_bounding_box(self, p):
         actual = transform(input)

         expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input
-        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
+        expected = features.BoundingBox.wrap_like(input, expected_image_tensor)
         assert_equal(expected, actual)
         assert actual.format == expected.format
         assert actual.image_size == expected.image_size
@@ -353,7 +354,7 @@ def test_features_bounding_box(self, p):
         actual = transform(input)

         expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
-        expected = features.BoundingBox.new_like(input, data=expected_image_tensor)
+        expected = features.BoundingBox.wrap_like(input, expected_image_tensor)
         assert_equal(expected, actual)
         assert actual.format == expected.format
         assert actual.image_size == expected.image_size
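
The new BATCH_EXTRA_DIMS constant keeps only the non-empty shape prefixes from DEFAULT_EXTRA_DIMS, so the parametrization in test_common above now produces batched inputs only. A small illustration with an assumed value for DEFAULT_EXTRA_DIMS (its real contents live in test/prototype_common_utils.py):

# assumed value, for illustration only
DEFAULT_EXTRA_DIMS = ((), (4,), (2, 3))

# the empty tuple (the unbatched case) is falsy and therefore dropped
BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
assert BATCH_EXTRA_DIMS == [(4,), (2, 3)]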

torchvision/prototype/features/_bounding_box.py

Lines changed: 29 additions & 28 deletions
@@ -19,6 +19,13 @@ class BoundingBox(_Feature):
     format: BoundingBoxFormat
     image_size: Tuple[int, int]

+    @classmethod
+    def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, image_size: Tuple[int, int]) -> BoundingBox:
+        bounding_box = tensor.as_subclass(cls)
+        bounding_box.format = format
+        bounding_box.image_size = image_size
+        return bounding_box
+
     def __new__(
         cls,
         data: Any,
@@ -29,52 +36,46 @@ def __new__(
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
     ) -> BoundingBox:
-        bounding_box = super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)

         if isinstance(format, str):
             format = BoundingBoxFormat.from_str(format.upper())
-        bounding_box.format = format
-
-        bounding_box.image_size = image_size

-        return bounding_box
-
-    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
-        return self._make_repr(format=self.format, image_size=self.image_size)
+        return cls._wrap(tensor, format=format, image_size=image_size)

     @classmethod
-    def new_like(
+    def wrap_like(
         cls,
         other: BoundingBox,
-        data: Any,
+        tensor: torch.Tensor,
         *,
-        format: Optional[Union[BoundingBoxFormat, str]] = None,
+        format: Optional[BoundingBoxFormat] = None,
         image_size: Optional[Tuple[int, int]] = None,
-        **kwargs: Any,
     ) -> BoundingBox:
-        return super().new_like(
-            other,
-            data,
+        return cls._wrap(
+            tensor,
             format=format if format is not None else other.format,
             image_size=image_size if image_size is not None else other.image_size,
-            **kwargs,
         )

+    def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
+        return self._make_repr(format=self.format, image_size=self.image_size)
+
     def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         if isinstance(format, str):
             format = BoundingBoxFormat.from_str(format.upper())

-        return BoundingBox.new_like(
+        return BoundingBox.wrap_like(
             self, self._F.convert_format_bounding_box(self, old_format=self.format, new_format=format), format=format
         )

     def horizontal_flip(self) -> BoundingBox:
         output = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size)
-        return BoundingBox.new_like(self, output)
+        return BoundingBox.wrap_like(self, output)

     def vertical_flip(self) -> BoundingBox:
         output = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size)
-        return BoundingBox.new_like(self, output)
+        return BoundingBox.wrap_like(self, output)

     def resize(  # type: ignore[override]
         self,
@@ -84,19 +85,19 @@ def resize( # type: ignore[override]
         antialias: bool = False,
     ) -> BoundingBox:
         output, image_size = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size)
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
         output, image_size = self._F.crop_bounding_box(
             self, self.format, top=top, left=left, height=height, width=width
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def center_crop(self, output_size: List[int]) -> BoundingBox:
         output, image_size = self._F.center_crop_bounding_box(
             self, format=self.format, image_size=self.image_size, output_size=output_size
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def resized_crop(
         self,
@@ -109,7 +110,7 @@ def resized_crop(
         antialias: bool = False,
     ) -> BoundingBox:
         output, image_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def pad(
         self,
@@ -120,7 +121,7 @@ def pad(
         output, image_size = self._F.pad_bounding_box(
             self, format=self.format, image_size=self.image_size, padding=padding, padding_mode=padding_mode
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def rotate(
         self,
@@ -133,7 +134,7 @@ def rotate(
         output, image_size = self._F.rotate_bounding_box(
             self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
         )
-        return BoundingBox.new_like(self, output, image_size=image_size)
+        return BoundingBox.wrap_like(self, output, image_size=image_size)

     def affine(
         self,
@@ -155,7 +156,7 @@ def affine(
             shear=shear,
             center=center,
         )
-        return BoundingBox.new_like(self, output, dtype=output.dtype)
+        return BoundingBox.wrap_like(self, output)

     def perspective(
         self,
@@ -164,7 +165,7 @@ def perspective(
         fill: FillTypeJIT = None,
     ) -> BoundingBox:
         output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs)
-        return BoundingBox.new_like(self, output, dtype=output.dtype)
+        return BoundingBox.wrap_like(self, output)

     def elastic(
         self,
@@ -173,4 +174,4 @@ def elastic(
         fill: FillTypeJIT = None,
     ) -> BoundingBox:
         output = self._F.elastic_bounding_box(self, self.format, displacement)
-        return BoundingBox.new_like(self, output, dtype=output.dtype)
+        return BoundingBox.wrap_like(self, output)
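
A hedged usage sketch of the new BoundingBox.wrap_like: metadata is copied from the reference box unless explicitly overridden, which is how the geometry methods above re-wrap their kernel outputs. The concrete values are illustrative:

import torch
from torchvision.prototype import features

bbox = features.BoundingBox([10, 20, 30, 40], format="XYXY", image_size=(100, 100))

halved = torch.div(bbox, 2, rounding_mode="floor")  # plain tensor math returns a torch.Tensor (only .clone() and .to() re-wrap)
wrapped = features.BoundingBox.wrap_like(bbox, halved, image_size=(50, 50))

assert wrapped.format is bbox.format   # inherited from the reference box
assert wrapped.image_size == (50, 50)  # explicitly overridden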

torchvision/prototype/features/_encoded.py

Lines changed: 10 additions & 1 deletion
@@ -14,6 +14,10 @@
 

 class EncodedData(_Feature):
+    @classmethod
+    def _wrap(cls: Type[D], tensor: torch.Tensor) -> D:
+        return tensor.as_subclass(cls)
+
     def __new__(
         cls,
         data: Any,
@@ -22,8 +26,13 @@ def __new__(
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
     ) -> EncodedData:
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8?
-        return super().__new__(cls, data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return cls._wrap(tensor)
+
+    @classmethod
+    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
+        return cls._wrap(tensor)

     @classmethod
     def from_file(cls: Type[D], file: BinaryIO, **kwargs: Any) -> D:
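
EncodedData.wrap_like carries no extra metadata, so wrapping only restores the subclass after an operation that returned a plain tensor. A minimal sketch, assuming a raw byte payload (the values are illustrative):

import torch
from torchvision.prototype import features

data = features.EncodedData(torch.randint(0, 256, (1024,), dtype=torch.uint8))
head = data[:512]  # indexing returns a plain torch.Tensor (only .clone() and .to() re-wrap)
head = features.EncodedData.wrap_like(data, head)
assert type(head) is features.EncodedData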

torchvision/prototype/features/_feature.py

Lines changed: 23 additions & 32 deletions
@@ -21,48 +21,39 @@ def is_simple_tensor(inpt: Any) -> bool:
 class _Feature(torch.Tensor):
     __F: Optional[ModuleType] = None

-    def __new__(
-        cls: Type[F],
+    @staticmethod
+    def _to_tensor(
         data: Any,
-        *,
         dtype: Optional[torch.dtype] = None,
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: bool = False,
-    ) -> F:
-        return (
-            torch.as_tensor(  # type: ignore[return-value]
-                data,
-                dtype=dtype,
-                device=device,
-            )
-            .as_subclass(cls)
-            .requires_grad_(requires_grad)
-        )
+    ) -> torch.Tensor:
+        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

-    @classmethod
-    def new_like(
-        cls: Type[F],
-        other: F,
+    # FIXME: this is just here for BC with the prototype datasets. Some datasets use the _Feature directly to have a
+    # a no-op input for the prototype transforms. For this use case, we can't use plain tensors, since they will be
+    # interpreted as images. We should decide if we want a public no-op feature like `GenericFeature` or make this one
+    # public again.
+    def __new__(
+        cls,
         data: Any,
-        *,
         dtype: Optional[torch.dtype] = None,
         device: Optional[Union[torch.device, str, int]] = None,
-        requires_grad: Optional[bool] = None,
-        **kwargs: Any,
-    ) -> F:
-        return cls(
-            data,
-            dtype=dtype if dtype is not None else other.dtype,
-            device=device if device is not None else other.device,
-            requires_grad=requires_grad if requires_grad is not None else other.requires_grad,
-            **kwargs,
-        )
+        requires_grad: bool = False,
+    ) -> _Feature:
+        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
+        return tensor.as_subclass(_Feature)
+
+    @classmethod
+    def wrap_like(cls: Type[F], other: F, tensor: torch.Tensor) -> F:
+        # FIXME: this is just here for BC with the prototype datasets. See __new__ for details. If that is resolved,
+        # this method should be made abstract
+        # raise NotImplementedError
+        return tensor.as_subclass(cls)

     _NO_WRAPPING_EXCEPTIONS = {
-        torch.Tensor.clone: lambda cls, input, output: cls.new_like(input, output),
-        torch.Tensor.to: lambda cls, input, output: cls.new_like(
-            input, output, dtype=output.dtype, device=output.device
-        ),
+        torch.Tensor.clone: lambda cls, input, output: cls.wrap_like(input, output),
+        torch.Tensor.to: lambda cls, input, output: cls.wrap_like(input, output),
         # We don't need to wrap the output of `Tensor.requires_grad_`, since it is an inplace operation and thus
         # retains the type automatically
         torch.Tensor.requires_grad_: lambda cls, input, output: output,
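
The reworked _NO_WRAPPING_EXCEPTIONS table encodes the re-wrapping rules: .clone() and .to() hand their output to wrap_like, the in-place requires_grad_ already keeps the type, and other operations are left unwrapped (per the test above, "any operation besides .to() and .clone()"). A short sketch of the resulting behaviour, using Label as a stand-in for any feature subclass:

import torch
from torchvision.prototype import features

label = features.Label([0, 1, 0], categories=["foo", "bar"])

assert type(label.clone()) is features.Label            # re-wrapped via wrap_like
assert type(label.to(torch.float32)) is features.Label  # re-wrapped via wrap_like
assert type(label * 2) is torch.Tensor                  # not in the exception table: unwrapped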
