Skip to content

Commit 0e128f3

Browse files
committed
Add support of devices on tests.
1 parent 544967e commit 0e128f3

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

references/classification/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66
import torch.utils.data
7+
from torch.utils.data.dataloader import default_collate
78
from torch import nn
89
import torchvision
910
from torchvision.transforms.functional import InterpolationMode
@@ -170,7 +171,7 @@ def main(args):
170171
if args.mixup_alpha > 0.0 or args.cutmix_alpha > 0.0:
171172
mixupcutmix = torchvision.transforms.RandomMixupCutmix(len(dataset.classes), mixup_alpha=args.mixup_alpha,
172173
cutmix_alpha=args.cutmix_alpha)
173-
collate_fn = lambda batch: mixupcutmix(*torch.utils.data._utils.collate.default_collate(batch)) # noqa: E731
174+
collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731
174175
data_loader = torch.utils.data.DataLoader(
175176
dataset, batch_size=args.batch_size,
176177
sampler=train_sampler, num_workers=args.workers, pin_memory=True,

test/test_transforms_tensor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,8 @@ def test_random_mixupcutmix_with_invalid_data():
764764
t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, ), dtype=torch.int32))
765765

766766

767-
def test_random_mixupcutmix_with_real_data():
767+
@pytest.mark.parametrize('device', cpu_and_gpu())
768+
def test_random_mixupcutmix_with_real_data(device):
768769
torch.manual_seed(12)
769770

770771
# Build dummy dataset
@@ -773,7 +774,8 @@ def test_random_mixupcutmix_with_real_data():
773774
fullpath = (os.path.dirname(os.path.abspath(__file__)), 'assets') + test_file
774775
img = read_image(get_file_path_2(*fullpath))
775776
images.append(F.resize(img, [224, 224]))
776-
dataset = TensorDataset(torch.stack(images).to(torch.float32), torch.tensor([0, 1]))
777+
dataset = TensorDataset(torch.stack(images).to(device=device, dtype=torch.float32),
778+
torch.tensor([0, 1], device=device))
777779

778780
# Use mixup in the collate
779781
mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0)

0 commit comments

Comments
 (0)