|
| 1 | +import pytest |
| 2 | +import torch |
| 3 | + |
| 4 | + |
| 5 | +@pytest.fixture(params=range(14)) |
| 6 | +def test_data_binary(request): |
| 7 | + return [ |
| 8 | + # Binary accuracy on input of shape (N, 1) or (N, ) |
| 9 | + (torch.randint(0, 2, size=(10,)), torch.randint(0, 2, size=(10,)), 1), |
| 10 | + (torch.randint(0, 2, size=(10, 1)), torch.randint(0, 2, size=(10, 1)), 1), |
| 11 | + # updated batches |
| 12 | + (torch.randint(0, 2, size=(50,)), torch.randint(0, 2, size=(50,)), 16), |
| 13 | + (torch.randint(0, 2, size=(50, 1)), torch.randint(0, 2, size=(50, 1)), 16), |
| 14 | + # Binary accuracy on input of shape (N, L) |
| 15 | + (torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1), |
| 16 | + (torch.randint(0, 2, size=(10, 1, 5)), torch.randint(0, 2, size=(10, 1, 5)), 1), |
| 17 | + # updated batches |
| 18 | + (torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16), |
| 19 | + (torch.randint(0, 2, size=(50, 1, 5)), torch.randint(0, 2, size=(50, 1, 5)), 16), |
| 20 | + # Binary accuracy on input of shape (N, H, W) |
| 21 | + (torch.randint(0, 2, size=(10, 12, 10)), torch.randint(0, 2, size=(10, 12, 10)), 1), |
| 22 | + (torch.randint(0, 2, size=(10, 1, 12, 10)), torch.randint(0, 2, size=(10, 1, 12, 10)), 1), |
| 23 | + # updated batches |
| 24 | + (torch.randint(0, 2, size=(50, 12, 10)), torch.randint(0, 2, size=(50, 12, 10)), 16), |
| 25 | + (torch.randint(0, 2, size=(50, 1, 12, 10)), torch.randint(0, 2, size=(50, 1, 12, 10)), 16), |
| 26 | + # Corner case with all zeros predictions |
| 27 | + (torch.zeros(size=(10,), dtype=torch.long), torch.randint(0, 2, size=(10,)), 1), |
| 28 | + (torch.zeros(size=(10, 1), dtype=torch.long), torch.randint(0, 2, size=(10, 1)), 1), |
| 29 | + ][request.param] |
| 30 | + |
| 31 | + |
| 32 | +@pytest.fixture(params=range(14)) |
| 33 | +def test_data_multiclass(request): |
| 34 | + return [ |
| 35 | + # Multiclass input data of shape (N, ) and (N, C) |
| 36 | + (torch.rand(10, 6), torch.randint(0, 6, size=(10,)), 1), |
| 37 | + (torch.rand(10, 4), torch.randint(0, 4, size=(10,)), 1), |
| 38 | + # updated batches |
| 39 | + (torch.rand(50, 6), torch.randint(0, 6, size=(50,)), 16), |
| 40 | + (torch.rand(50, 4), torch.randint(0, 4, size=(50,)), 16), |
| 41 | + # Multiclass input data of shape (N, L) and (N, C, L) |
| 42 | + (torch.rand(10, 5, 8), torch.randint(0, 5, size=(10, 8)), 1), |
| 43 | + (torch.rand(10, 8, 12), torch.randint(0, 8, size=(10, 12)), 1), |
| 44 | + # updated batches |
| 45 | + (torch.rand(50, 5, 8), torch.randint(0, 5, size=(50, 8)), 16), |
| 46 | + (torch.rand(50, 8, 12), torch.randint(0, 8, size=(50, 12)), 16), |
| 47 | + # Multiclass input data of shape (N, H, W, ...) and (N, C, H, W, ...) |
| 48 | + (torch.rand(10, 5, 18, 16), torch.randint(0, 5, size=(10, 18, 16)), 1), |
| 49 | + (torch.rand(10, 7, 20, 12), torch.randint(0, 7, size=(10, 20, 12)), 1), |
| 50 | + # updated batches |
| 51 | + (torch.rand(50, 5, 18, 16), torch.randint(0, 5, size=(50, 18, 16)), 16), |
| 52 | + (torch.rand(50, 7, 20, 12), torch.randint(0, 7, size=(50, 20, 12)), 16), |
| 53 | + # Corner case with all zeros predictions |
| 54 | + (torch.zeros(size=(10, 6)), torch.randint(0, 6, size=(10,)), 1), |
| 55 | + (torch.zeros(size=(10, 4)), torch.randint(0, 4, size=(10,)), 1), |
| 56 | + ][request.param] |
| 57 | + |
| 58 | + |
| 59 | +@pytest.fixture(params=range(14)) |
| 60 | +def test_data_multilabel(request): |
| 61 | + return [ |
| 62 | + # Multilabel input data of shape (N, C) |
| 63 | + (torch.randint(0, 2, size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1), |
| 64 | + (torch.randint(0, 2, size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1), |
| 65 | + # updated batches |
| 66 | + (torch.randint(0, 2, size=(50, 5)), torch.randint(0, 2, size=(50, 5)), 16), |
| 67 | + (torch.randint(0, 2, size=(50, 4)), torch.randint(0, 2, size=(50, 4)), 16), |
| 68 | + # Multilabel input data of shape (N, C, L) |
| 69 | + (torch.randint(0, 2, size=(10, 5, 10)), torch.randint(0, 2, size=(10, 5, 10)), 1), |
| 70 | + (torch.randint(0, 2, size=(10, 4, 10)), torch.randint(0, 2, size=(10, 4, 10)), 1), |
| 71 | + # updated batches |
| 72 | + (torch.randint(0, 2, size=(50, 5, 10)), torch.randint(0, 2, size=(50, 5, 10)), 16), |
| 73 | + (torch.randint(0, 2, size=(50, 4, 10)), torch.randint(0, 2, size=(50, 4, 10)), 16), |
| 74 | + # Multilabel input data of shape (N, C, H, W) |
| 75 | + (torch.randint(0, 2, size=(10, 5, 18, 16)), torch.randint(0, 2, size=(10, 5, 18, 16)), 1), |
| 76 | + (torch.randint(0, 2, size=(10, 4, 20, 23)), torch.randint(0, 2, size=(10, 4, 20, 23)), 1), |
| 77 | + # updated batches |
| 78 | + (torch.randint(0, 2, size=(50, 5, 18, 16)), torch.randint(0, 2, size=(50, 5, 18, 16)), 16), |
| 79 | + (torch.randint(0, 2, size=(50, 4, 20, 23)), torch.randint(0, 2, size=(50, 4, 20, 23)), 16), |
| 80 | + # Corner case with all zeros predictions |
| 81 | + (torch.zeros(size=(10, 5)), torch.randint(0, 2, size=(10, 5)), 1), |
| 82 | + (torch.zeros(size=(10, 4)), torch.randint(0, 2, size=(10, 4)), 1), |
| 83 | + ][request.param] |
0 commit comments