File tree Expand file tree Collapse file tree 2 files changed +9
-0
lines changed Expand file tree Collapse file tree 2 files changed +9
-0
lines changed Original file line number Diff line number Diff line change 2
2
3
3
from tqdm_loggable .auto import tqdm
4
4
5
+ import jax
5
6
import jax .numpy as jnp
6
7
from jax import jit
7
8
from jax .flatten_util import ravel_pytree
@@ -1119,6 +1120,9 @@ def line_search(
1119
1120
1120
1121
count += 1
1121
1122
1123
+ if new_unitcell [0 , 0 ][0 ][0 ].chi != unitcell [0 , 0 ][0 ][0 ].chi :
1124
+ jax .clear_caches ()
1125
+
1122
1126
if count == varipeps_config .line_search_max_steps :
1123
1127
raise NoSuitableStepSizeError (f"Count { count } , Last alpha { alpha } " )
1124
1128
Original file line number Diff line number Diff line change 8
8
9
9
from tqdm_loggable .auto import tqdm
10
10
11
+ import jax
11
12
from jax import jit
12
13
import jax .numpy as jnp
13
14
from jax .lax import scan
@@ -405,6 +406,7 @@ def random_noise(a):
405
406
while count < varipeps_config .optimizer_max_steps :
406
407
runtime_start = time .perf_counter ()
407
408
409
+ chi_before_ctmrg = working_unitcell [0 , 0 ][0 ][0 ].chi
408
410
try :
409
411
if varipeps_config .ad_use_custom_vjp :
410
412
(
@@ -498,6 +500,9 @@ def random_noise(a):
498
500
499
501
continue
500
502
503
+ if working_unitcell [0 , 0 ][0 ][0 ].chi != chi_before_ctmrg :
504
+ jax .clear_caches ()
505
+
501
506
working_gradient = [elem .conj () for elem in working_gradient_seq ]
502
507
503
508
if signal_reset_descent_dir :
You can’t perform that action at this time.
0 commit comments