@@ -19,6 +19,13 @@ class BoundingBox(_Feature):
19
19
format : BoundingBoxFormat
20
20
image_size : Tuple [int , int ]
21
21
22
+ @classmethod
23
+ def _wrap (cls , tensor : torch .Tensor , * , format : BoundingBoxFormat , image_size : Tuple [int , int ]) -> BoundingBox :
24
+ bounding_box = tensor .as_subclass (cls )
25
+ bounding_box .format = format
26
+ bounding_box .image_size = image_size
27
+ return bounding_box
28
+
22
29
def __new__ (
23
30
cls ,
24
31
data : Any ,
@@ -29,52 +36,46 @@ def __new__(
29
36
device : Optional [Union [torch .device , str , int ]] = None ,
30
37
requires_grad : bool = False ,
31
38
) -> BoundingBox :
32
- bounding_box = super (). __new__ ( cls , data , dtype = dtype , device = device , requires_grad = requires_grad )
39
+ tensor = cls . _to_tensor ( data , dtype = dtype , device = device , requires_grad = requires_grad )
33
40
34
41
if isinstance (format , str ):
35
42
format = BoundingBoxFormat .from_str (format .upper ())
36
- bounding_box .format = format
37
-
38
- bounding_box .image_size = image_size
39
43
40
- return bounding_box
41
-
42
- def __repr__ (self , * , tensor_contents : Any = None ) -> str : # type: ignore[override]
43
- return self ._make_repr (format = self .format , image_size = self .image_size )
44
+ return cls ._wrap (tensor , format = format , image_size = image_size )
44
45
45
46
@classmethod
46
- def new_like (
47
+ def wrap_like (
47
48
cls ,
48
49
other : BoundingBox ,
49
- data : Any ,
50
+ tensor : torch . Tensor ,
50
51
* ,
51
- format : Optional [Union [ BoundingBoxFormat , str ] ] = None ,
52
+ format : Optional [BoundingBoxFormat ] = None ,
52
53
image_size : Optional [Tuple [int , int ]] = None ,
53
- ** kwargs : Any ,
54
54
) -> BoundingBox :
55
- return super ().new_like (
56
- other ,
57
- data ,
55
+ return cls ._wrap (
56
+ tensor ,
58
57
format = format if format is not None else other .format ,
59
58
image_size = image_size if image_size is not None else other .image_size ,
60
- ** kwargs ,
61
59
)
62
60
61
+ def __repr__ (self , * , tensor_contents : Any = None ) -> str : # type: ignore[override]
62
+ return self ._make_repr (format = self .format , image_size = self .image_size )
63
+
63
64
def to_format (self , format : Union [str , BoundingBoxFormat ]) -> BoundingBox :
64
65
if isinstance (format , str ):
65
66
format = BoundingBoxFormat .from_str (format .upper ())
66
67
67
- return BoundingBox .new_like (
68
+ return BoundingBox .wrap_like (
68
69
self , self ._F .convert_format_bounding_box (self , old_format = self .format , new_format = format ), format = format
69
70
)
70
71
71
72
def horizontal_flip (self ) -> BoundingBox :
72
73
output = self ._F .horizontal_flip_bounding_box (self , format = self .format , image_size = self .image_size )
73
- return BoundingBox .new_like (self , output )
74
+ return BoundingBox .wrap_like (self , output )
74
75
75
76
def vertical_flip (self ) -> BoundingBox :
76
77
output = self ._F .vertical_flip_bounding_box (self , format = self .format , image_size = self .image_size )
77
- return BoundingBox .new_like (self , output )
78
+ return BoundingBox .wrap_like (self , output )
78
79
79
80
def resize ( # type: ignore[override]
80
81
self ,
@@ -84,19 +85,19 @@ def resize( # type: ignore[override]
84
85
antialias : bool = False ,
85
86
) -> BoundingBox :
86
87
output , image_size = self ._F .resize_bounding_box (self , image_size = self .image_size , size = size , max_size = max_size )
87
- return BoundingBox .new_like (self , output , image_size = image_size )
88
+ return BoundingBox .wrap_like (self , output , image_size = image_size )
88
89
89
90
def crop (self , top : int , left : int , height : int , width : int ) -> BoundingBox :
90
91
output , image_size = self ._F .crop_bounding_box (
91
92
self , self .format , top = top , left = left , height = height , width = width
92
93
)
93
- return BoundingBox .new_like (self , output , image_size = image_size )
94
+ return BoundingBox .wrap_like (self , output , image_size = image_size )
94
95
95
96
def center_crop (self , output_size : List [int ]) -> BoundingBox :
96
97
output , image_size = self ._F .center_crop_bounding_box (
97
98
self , format = self .format , image_size = self .image_size , output_size = output_size
98
99
)
99
- return BoundingBox .new_like (self , output , image_size = image_size )
100
+ return BoundingBox .wrap_like (self , output , image_size = image_size )
100
101
101
102
def resized_crop (
102
103
self ,
@@ -109,7 +110,7 @@ def resized_crop(
109
110
antialias : bool = False ,
110
111
) -> BoundingBox :
111
112
output , image_size = self ._F .resized_crop_bounding_box (self , self .format , top , left , height , width , size = size )
112
- return BoundingBox .new_like (self , output , image_size = image_size )
113
+ return BoundingBox .wrap_like (self , output , image_size = image_size )
113
114
114
115
def pad (
115
116
self ,
@@ -120,7 +121,7 @@ def pad(
120
121
output , image_size = self ._F .pad_bounding_box (
121
122
self , format = self .format , image_size = self .image_size , padding = padding , padding_mode = padding_mode
122
123
)
123
- return BoundingBox .new_like (self , output , image_size = image_size )
124
+ return BoundingBox .wrap_like (self , output , image_size = image_size )
124
125
125
126
def rotate (
126
127
self ,
@@ -133,7 +134,7 @@ def rotate(
133
134
output , image_size = self ._F .rotate_bounding_box (
134
135
self , format = self .format , image_size = self .image_size , angle = angle , expand = expand , center = center
135
136
)
136
- return BoundingBox .new_like (self , output , image_size = image_size )
137
+ return BoundingBox .wrap_like (self , output , image_size = image_size )
137
138
138
139
def affine (
139
140
self ,
@@ -155,7 +156,7 @@ def affine(
155
156
shear = shear ,
156
157
center = center ,
157
158
)
158
- return BoundingBox .new_like (self , output , dtype = output . dtype )
159
+ return BoundingBox .wrap_like (self , output )
159
160
160
161
def perspective (
161
162
self ,
@@ -164,7 +165,7 @@ def perspective(
164
165
fill : FillTypeJIT = None ,
165
166
) -> BoundingBox :
166
167
output = self ._F .perspective_bounding_box (self , self .format , perspective_coeffs )
167
- return BoundingBox .new_like (self , output , dtype = output . dtype )
168
+ return BoundingBox .wrap_like (self , output )
168
169
169
170
def elastic (
170
171
self ,
@@ -173,4 +174,4 @@ def elastic(
173
174
fill : FillTypeJIT = None ,
174
175
) -> BoundingBox :
175
176
output = self ._F .elastic_bounding_box (self , self .format , displacement )
176
- return BoundingBox .new_like (self , output , dtype = output . dtype )
177
+ return BoundingBox .wrap_like (self , output )
0 commit comments