We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4568503 commit d7ba3ddCopy full SHA for d7ba3dd
varipeps/utils/svd.py
@@ -58,6 +58,7 @@ def _svd_jvp_rule(primals, tangents):
58
ds = jnp.real(jnp.diagonal(dS, 0, -2, -1))
59
60
s_sums = s_dim + _T(s_dim)
61
+ s_sums = jnp.where(s_sums > 0, s_sums, 1)
62
s_diffs = s_dim - _T(s_dim)
63
s_diffs = jnp.where(jnp.abs(s_diffs / s[0]) >= 1e-12, s_diffs, 0)
64
s_diffs_zeros = jnp.ones((), dtype=A.dtype) * (
0 commit comments