Skip to content

Commit 4568503

Browse files
committed
Implement alternative version of SVD derivative which should be more stable
1 parent beb6ae7 commit 4568503

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

varipeps/utils/svd.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -57,22 +57,33 @@ def _svd_jvp_rule(primals, tangents):
5757
dS = Ut @ dA @ V
5858
ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
5959

60-
s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
61-
# s_diffs = jnp.where(s_diffs / (s[0] ** 2) >= 1e-12, s_diffs, 0)
60+
s_sums = s_dim + _T(s_dim)
61+
s_diffs = s_dim - _T(s_dim)
62+
s_diffs = jnp.where(jnp.abs(s_diffs / s[0]) >= 1e-12, s_diffs, 0)
6263
s_diffs_zeros = jnp.ones((), dtype=A.dtype) * (
6364
s_diffs == 0.0
6465
) # is 1. where s_diffs is 0. and is 0. everywhere else
6566
s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
6667
F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
67-
dSS = s_dim.astype(A.dtype) * dS # dS.dot(jnp.diag(s))
68-
SdS = _T(s_dim.astype(A.dtype)) * dS # jnp.diag(s).dot(dS)
68+
dSS = dS * (s_dim / s_sums).astype(A.dtype) # dS.dot(s_j / (s_i + s_j))
69+
SdS = (_T(s_dim) / s_sums).astype(A.dtype) * dS # (s_i / (s_i + s_j)).dot(dS)
70+
71+
# s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
72+
# # s_diffs = jnp.where(s_diffs / (s[0] ** 2) >= 1e-12, s_diffs, 0)
73+
# s_diffs_zeros = jnp.ones((), dtype=A.dtype) * (
74+
# s_diffs == 0.0
75+
# ) # is 1. where s_diffs is 0. and is 0. everywhere else
76+
# s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
77+
# F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
78+
# dSS = s_dim.astype(A.dtype) * dS # dS.dot(jnp.diag(s))
79+
# SdS = _T(s_dim.astype(A.dtype)) * dS # jnp.diag(s).dot(dS)
6980

7081
s_zeros = (s == 0).astype(s.dtype)
7182
s_inv = 1 / (s + s_zeros) - s_zeros
7283
s_inv_mat = jnp.vectorize(jnp.diag, signature="(k)->(k,k)")(s_inv)
7384
dUdV_diag = 0.5 * (dS - _H(dS)) * s_inv_mat.astype(A.dtype)
74-
dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
75-
dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS)))
85+
dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + 0.5 * dUdV_diag)
86+
dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS)) + 0.5 * dUdV_diag)
7687

7788
m, n = A.shape[-2:]
7889
if m > n:

0 commit comments

Comments
 (0)