Skip to content

Commit 37a0d8d

Browse files
authored
[BC-breaking] Fix for integer fill value in constant padding (#2284)
* Bugfix in pad * Address review comments * Fix lint
1 parent 3902140 commit 37a0d8d

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

test/test_transforms.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,13 +299,22 @@ def test_pad(self):
299299
width = random.randint(10, 32) * 2
300300
img = torch.ones(3, height, width)
301301
padding = random.randint(1, 20)
302+
fill = random.randint(1, 50)
302303
result = transforms.Compose([
303304
transforms.ToPILImage(),
304-
transforms.Pad(padding),
305+
transforms.Pad(padding, fill=fill),
305306
transforms.ToTensor(),
306307
])(img)
307308
self.assertEqual(result.size(1), height + 2 * padding)
308309
self.assertEqual(result.size(2), width + 2 * padding)
310+
# check that all elements in the padded region correspond
311+
# to the pad value
312+
fill_v = fill / 255
313+
eps = 1e-5
314+
self.assertTrue((result[:, :padding, :] - fill_v).abs().max() < eps)
315+
self.assertTrue((result[:, :, :padding] - fill_v).abs().max() < eps)
316+
self.assertRaises(ValueError, transforms.Pad(padding, fill=(1, 2)),
317+
transforms.ToPILImage()(img))
309318

310319
def test_pad_with_tuple_of_pad_values(self):
311320
height = random.randint(10, 32) * 2

torchvision/transforms/functional.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,12 @@ def pad(img, padding, fill=0, padding_mode='constant'):
329329
'Padding mode should be either constant, edge, reflect or symmetric'
330330

331331
if padding_mode == 'constant':
332+
if isinstance(fill, numbers.Number):
333+
fill = (fill,) * len(img.getbands())
334+
if len(fill) != len(img.getbands()):
335+
raise ValueError('fill should have the same number of elements '
336+
'as the number of channels in the image '
337+
'({}), got {} instead'.format(len(img.getbands()), len(fill)))
332338
if img.mode == 'P':
333339
palette = img.getpalette()
334340
image = ImageOps.expand(img, border=padding, fill=fill)

0 commit comments

Comments
 (0)