Skip to content

Commit 61b67bb

Browse files
committed
Fixed failing tests on pytorch nightly using torch.load
1 parent 3c5e213 commit 61b67bb

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

tests/ignite/engine/test_deterministic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99
import torch
1010
import torch.nn as nn
11+
from packaging.version import Version
1112
from torch.optim import SGD
1213
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
1314

@@ -737,7 +738,11 @@ def write_data_grads_weights(e):
737738
grad_norms.append([i, total[1]] + out2)
738739

739740
if sd is not None:
740-
sd = torch.load(sd)
741+
if Version(torch.__version__) >= Version("1.13.0"):
742+
kwargs = {"weights_only": False}
743+
else:
744+
kwargs = {}
745+
sd = torch.load(sd, **kwargs)
741746
model.load_state_dict(sd[0])
742747
opt.load_state_dict(sd[1])
743748
from ignite.engine.deterministic import _repr_rng_state

tests/ignite/handlers/test_state_param_scheduler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,11 @@ def test_torch_save_load(dirname):
295295

296296
filepath = Path(dirname) / "dummy_lambda_state_parameter_scheduler.pt"
297297
torch.save(lambda_state_parameter_scheduler, filepath)
298-
loaded_lambda_state_parameter_scheduler = torch.load(filepath)
298+
if Version(torch.__version__) >= Version("1.13.0"):
299+
kwargs = {"weights_only": False}
300+
else:
301+
kwargs = {}
302+
loaded_lambda_state_parameter_scheduler = torch.load(filepath, **kwargs)
299303

300304
engine1 = Engine(lambda e, b: None)
301305
lambda_state_parameter_scheduler.attach(engine1, Events.EPOCH_COMPLETED)

0 commit comments

Comments
 (0)