Skip to content

Commit 5d3e92b

Browse files
committed
Try to avoid memory leaks by clearing jax caches if CTMRG heuristics forces a lot of JITs
1 parent 9bb2c94 commit 5d3e92b

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

varipeps/optimization/line_search.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from tqdm_loggable.auto import tqdm
44

5+
import jax
56
import jax.numpy as jnp
67
from jax import jit
78
from jax.flatten_util import ravel_pytree
@@ -1119,6 +1120,9 @@ def line_search(
11191120

11201121
count += 1
11211122

1123+
if new_unitcell[0, 0][0][0].chi != unitcell[0, 0][0][0].chi:
1124+
jax.clear_caches()
1125+
11221126
if count == varipeps_config.line_search_max_steps:
11231127
raise NoSuitableStepSizeError(f"Count {count}, Last alpha {alpha}")
11241128

varipeps/optimization/optimizer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from tqdm_loggable.auto import tqdm
1010

11+
import jax
1112
from jax import jit
1213
import jax.numpy as jnp
1314
from jax.lax import scan
@@ -405,6 +406,7 @@ def random_noise(a):
405406
while count < varipeps_config.optimizer_max_steps:
406407
runtime_start = time.perf_counter()
407408

409+
chi_before_ctmrg = working_unitcell[0, 0][0][0].chi
408410
try:
409411
if varipeps_config.ad_use_custom_vjp:
410412
(
@@ -498,6 +500,9 @@ def random_noise(a):
498500

499501
continue
500502

503+
if working_unitcell[0, 0][0][0].chi != chi_before_ctmrg:
504+
jax.clear_caches()
505+
501506
working_gradient = [elem.conj() for elem in working_gradient_seq]
502507

503508
if signal_reset_descent_dir:

0 commit comments

Comments
 (0)