diff --git a/test/assets/fakedata/draw_boxes_util.png b/test/assets/fakedata/draw_boxes_util.png index d64fa2f1f36..2c361c5fafd 100644 Binary files a/test/assets/fakedata/draw_boxes_util.png and b/test/assets/fakedata/draw_boxes_util.png differ diff --git a/test/test_utils.py b/test/test_utils.py index b9893cdd1ac..8c4cc620229 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,7 +7,10 @@ import unittest from io import BytesIO import torchvision.transforms.functional as F -from PIL import Image +from PIL import Image, __version__ as PILLOW_VERSION + + +PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.')) boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) @@ -120,8 +123,11 @@ def test_draw_boxes(self): res = Image.fromarray(result.permute(1, 2, 0).contiguous().numpy()) res.save(path) - expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) - self.assertTrue(torch.equal(result, expected)) + if PILLOW_VERSION >= (8, 2): + # The reference image is only valid for new PIL versions + expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1) + self.assertTrue(torch.equal(result, expected)) + # Check if modification is not in place self.assertTrue(torch.all(torch.eq(boxes, boxes_cp)).item()) self.assertTrue(torch.all(torch.eq(img, img_cp)).item())