Skip to content

Commit fd85588

Browse files
authored
adds available_device to test_entropy.py #3335 (#3358)
* adds available_device to test_entropy.py #3335 * fix test_case producing large values
1 parent b5e9dae commit fd85588

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

tests/ignite/metrics/test_entropy.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ def test_case(request):
4545

4646

4747
@pytest.mark.parametrize("n_times", range(5))
48-
def test_compute(n_times, test_case):
49-
ent = Entropy()
48+
def test_compute(n_times, test_case, available_device):
49+
ent = Entropy(device=available_device)
50+
assert ent._device == torch.device(available_device)
5051

5152
y_pred, y, batch_size = test_case
5253

@@ -59,14 +60,15 @@ def test_compute(n_times, test_case):
5960
else:
6061
ent.update((y_pred, y))
6162

62-
np_res = np_entropy(y_pred.numpy())
63+
np_res = np_entropy(y_pred.cpu().numpy())
6364

6465
assert isinstance(ent.compute(), float)
6566
assert pytest.approx(ent.compute()) == np_res
6667

6768

68-
def test_accumulator_detached():
69-
ent = Entropy()
69+
def test_accumulator_detached(available_device):
70+
ent = Entropy(device=available_device)
71+
assert ent._device == torch.device(available_device)
7072

7173
y_pred = torch.tensor([[2.0, 3.0], [-2.0, -1.0]], requires_grad=True)
7274
y = torch.zeros(2)

tests/ignite/metrics/test_hsic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_case(request) -> Tuple[Tensor, Tensor, int]:
7777
x = torch.randn(N, 5)
7878
y = x @ torch.normal(0.0, math.pi, size=(5, 3))
7979
y = (
80-
torch.stack([torch.sin(y[:, 0]), torch.cos(y[:, 1]), torch.exp(y[:, 2])], dim=1)
80+
torch.stack([torch.sin(y[:, 0]), torch.cos(y[:, 1]), torch.exp(y[:, 2]) / 10], dim=1)
8181
+ torch.randn_like(y) * 1e-4
8282
)
8383

0 commit comments

Comments
 (0)