Skip to content

Commit 93fc321

Browse files
committed
add xfails for arbitrary batch sizes on some kernels
1 parent f1e2bfa commit 93fc321

File tree

2 files changed

+70
-32
lines changed

2 files changed

+70
-32
lines changed

test/prototype_transforms_dispatcher_infos.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,23 @@ def fill_sequence_needs_broadcast(args_kwargs):
127127
)
128128

129129

130+
def xfail_all_tests(*, reason, condition):
131+
return [
132+
TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
133+
for test_name in [
134+
"test_scripted_smoke",
135+
"test_dispatch_simple_tensor",
136+
"test_dispatch_feature",
137+
]
138+
]
139+
140+
141+
xfails_degenerate_or_multi_batch_dims = xfail_all_tests(
142+
reason="See https://github.com/pytorch/vision/issues/6670 for details.",
143+
condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
144+
)
145+
146+
130147
DISPATCHER_INFOS = [
131148
DispatcherInfo(
132149
F.horizontal_flip,
@@ -243,6 +260,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
243260
pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
244261
test_marks=[
245262
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
263+
*xfails_degenerate_or_multi_batch_dims,
246264
],
247265
),
248266
DispatcherInfo(
@@ -253,6 +271,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
253271
features.Mask: F.elastic_mask,
254272
},
255273
pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
274+
test_marks=xfails_degenerate_or_multi_batch_dims,
256275
),
257276
DispatcherInfo(
258277
F.center_crop,
@@ -275,6 +294,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
275294
test_marks=[
276295
xfail_jit_python_scalar_arg("kernel_size"),
277296
xfail_jit_python_scalar_arg("sigma"),
297+
*xfails_degenerate_or_multi_batch_dims,
278298
],
279299
),
280300
DispatcherInfo(
@@ -283,6 +303,7 @@ def fill_sequence_needs_broadcast(args_kwargs):
283303
features.Image: F.equalize_image_tensor,
284304
},
285305
pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"),
306+
test_marks=xfails_degenerate_or_multi_batch_dims,
286307
),
287308
DispatcherInfo(
288309
F.invert,
@@ -318,6 +339,15 @@ def fill_sequence_needs_broadcast(args_kwargs):
318339
features.Image: F.adjust_sharpness_image_tensor,
319340
},
320341
pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"),
342+
test_marks=xfail_all_tests(
343+
reason="See https://github.com/pytorch/vision/issues/6670 for details.",
344+
condition=lambda args_kwargs: all(dim > 2 for dim in args_kwargs.args[0].shape[-2:])
345+
and (
346+
len(args_kwargs.args[0].shape) > 4
347+
or not all(args_kwargs.args[0].shape[:-4])
348+
or args_kwargs.args[0].shape[-4:-2] == (0, 3)
349+
),
350+
),
321351
),
322352
DispatcherInfo(
323353
F.erase,

test/prototype_transforms_kernel_infos.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,25 @@ def xfail_jit_list_of_ints(name, *, reason=None):
143143
)
144144

145145

146+
def xfail_all_tests(*, reason, condition):
147+
return [
148+
TestMark(("TestKernels", test_name), pytest.mark.xfail(reason=reason), condition=condition)
149+
for test_name in [
150+
"test_scripted_vs_eager",
151+
"test_batched_vs_single",
152+
"test_no_inplace",
153+
"test_cuda_vs_cpu",
154+
"test_dtype_and_device_consistency",
155+
]
156+
]
157+
158+
159+
xfails_image_degenerate_or_multi_batch_dims = xfail_all_tests(
160+
reason="See https://github.com/pytorch/vision/issues/6670 for details.",
161+
condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
162+
)
163+
164+
146165
KERNEL_INFOS = []
147166

148167

@@ -1093,11 +1112,7 @@ def sample_inputs_pad_video():
10931112

10941113

10951114
def sample_inputs_perspective_image_tensor():
1096-
for image_loader in make_image_loaders(
1097-
sizes=["random"],
1098-
# FIXME: kernel should support arbitrary batch sizes
1099-
extra_dims=[(), (4,)],
1100-
):
1115+
for image_loader in make_image_loaders(sizes=["random"]):
11011116
for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
11021117
yield ArgsKwargs(image_loader, fill=fill, perspective_coeffs=_PERSPECTIVE_COEFFS[0])
11031118

@@ -1117,11 +1132,7 @@ def sample_inputs_perspective_bounding_box():
11171132

11181133

11191134
def sample_inputs_perspective_mask():
1120-
for mask_loader in make_mask_loaders(
1121-
sizes=["random"],
1122-
# FIXME: kernel should support arbitrary batch sizes
1123-
extra_dims=[(), (4,)],
1124-
):
1135+
for mask_loader in make_mask_loaders(sizes=["random"]):
11251136
yield ArgsKwargs(mask_loader, perspective_coeffs=_PERSPECTIVE_COEFFS[0])
11261137

11271138

@@ -1145,6 +1156,7 @@ def sample_inputs_perspective_video():
11451156
reference_fn=pil_reference_wrapper(F.perspective_image_pil),
11461157
reference_inputs_fn=reference_inputs_perspective_image_tensor,
11471158
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
1159+
test_marks=xfails_image_degenerate_or_multi_batch_dims,
11481160
),
11491161
KernelInfo(
11501162
F.perspective_bounding_box,
@@ -1156,6 +1168,7 @@ def sample_inputs_perspective_video():
11561168
reference_fn=pil_reference_wrapper(F.perspective_image_pil),
11571169
reference_inputs_fn=reference_inputs_perspective_mask,
11581170
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
1171+
test_marks=xfails_image_degenerate_or_multi_batch_dims,
11591172
),
11601173
KernelInfo(
11611174
F.perspective_video,
@@ -1170,11 +1183,7 @@ def _get_elastic_displacement(image_size):
11701183

11711184

11721185
def sample_inputs_elastic_image_tensor():
1173-
for image_loader in make_image_loaders(
1174-
sizes=["random"],
1175-
# FIXME: kernel should support arbitrary batch sizes
1176-
extra_dims=[(), (4,)],
1177-
):
1186+
for image_loader in make_image_loaders(sizes=["random"]):
11781187
displacement = _get_elastic_displacement(image_loader.image_size)
11791188
for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]:
11801189
yield ArgsKwargs(image_loader, displacement=displacement, fill=fill)
@@ -1205,11 +1214,7 @@ def sample_inputs_elastic_bounding_box():
12051214

12061215

12071216
def sample_inputs_elastic_mask():
1208-
for mask_loader in make_mask_loaders(
1209-
sizes=["random"],
1210-
# FIXME: kernel should support arbitrary batch sizes
1211-
extra_dims=[(), (4,)],
1212-
):
1217+
for mask_loader in make_mask_loaders(sizes=["random"]):
12131218
displacement = _get_elastic_displacement(mask_loader.shape[-2:])
12141219
yield ArgsKwargs(mask_loader, displacement=displacement)
12151220

@@ -1234,6 +1239,7 @@ def sample_inputs_elastic_video():
12341239
reference_fn=pil_reference_wrapper(F.elastic_image_pil),
12351240
reference_inputs_fn=reference_inputs_elastic_image_tensor,
12361241
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
1242+
test_marks=xfails_image_degenerate_or_multi_batch_dims,
12371243
),
12381244
KernelInfo(
12391245
F.elastic_bounding_box,
@@ -1245,6 +1251,7 @@ def sample_inputs_elastic_video():
12451251
reference_fn=pil_reference_wrapper(F.elastic_image_pil),
12461252
reference_inputs_fn=reference_inputs_elastic_mask,
12471253
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
1254+
test_marks=xfails_image_degenerate_or_multi_batch_dims,
12481255
),
12491256
KernelInfo(
12501257
F.elastic_video,
@@ -1346,11 +1353,7 @@ def sample_inputs_center_crop_video():
13461353

13471354
def sample_inputs_gaussian_blur_image_tensor():
13481355
make_gaussian_blur_image_loaders = functools.partial(
1349-
make_image_loaders,
1350-
sizes=["random"],
1351-
color_spaces=[features.ColorSpace.RGB],
1352-
# FIXME: kernel should support arbitrary batch sizes
1353-
extra_dims=[(), (4,)],
1356+
make_image_loaders, sizes=["random"], color_spaces=[features.ColorSpace.RGB]
13541357
)
13551358

13561359
for image_loader, kernel_size in itertools.product(make_gaussian_blur_image_loaders(), [5, (3, 3), [3, 3]]):
@@ -1376,6 +1379,7 @@ def sample_inputs_gaussian_blur_video():
13761379
test_marks=[
13771380
xfail_jit_python_scalar_arg("kernel_size"),
13781381
xfail_jit_python_scalar_arg("sigma"),
1382+
*xfails_image_degenerate_or_multi_batch_dims,
13791383
],
13801384
),
13811385
KernelInfo(
@@ -1388,11 +1392,7 @@ def sample_inputs_gaussian_blur_video():
13881392

13891393
def sample_inputs_equalize_image_tensor():
13901394
for image_loader in make_image_loaders(
1391-
sizes=["random"],
1392-
# FIXME: kernel should support arbitrary batch sizes
1393-
extra_dims=[(), (4,)],
1394-
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB),
1395-
dtypes=[torch.uint8],
1395+
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
13961396
):
13971397
yield ArgsKwargs(image_loader)
13981398

@@ -1418,6 +1418,7 @@ def sample_inputs_equalize_video():
14181418
reference_fn=pil_reference_wrapper(F.equalize_image_pil),
14191419
reference_inputs_fn=reference_inputs_equalize_image_tensor,
14201420
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
1421+
test_marks=xfails_image_degenerate_or_multi_batch_dims,
14211422
),
14221423
KernelInfo(
14231424
F.equalize_video,
@@ -1594,8 +1595,6 @@ def sample_inputs_adjust_sharpness_image_tensor():
15941595
for image_loader in make_image_loaders(
15951596
sizes=["random", (2, 2)],
15961597
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB),
1597-
# FIXME: kernel should support arbitrary batch sizes
1598-
extra_dims=[(), (4,)],
15991598
):
16001599
yield ArgsKwargs(image_loader, sharpness_factor=_ADJUST_SHARPNESS_FACTORS[0])
16011600

@@ -1622,6 +1621,15 @@ def sample_inputs_adjust_sharpness_video():
16221621
reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil),
16231622
reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor,
16241623
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
1624+
test_marks=xfail_all_tests(
1625+
reason="See https://github.com/pytorch/vision/issues/6670 for details.",
1626+
condition=lambda args_kwargs: all(dim > 2 for dim in args_kwargs.args[0].shape[-2:])
1627+
and (
1628+
len(args_kwargs.args[0].shape) > 4
1629+
or not all(args_kwargs.args[0].shape[:-4])
1630+
or args_kwargs.args[0].shape[-4:-2] == (0, 3)
1631+
),
1632+
),
16251633
),
16261634
KernelInfo(
16271635
F.adjust_sharpness_video,

0 commit comments

Comments
 (0)