@@ -57,22 +57,33 @@ def _svd_jvp_rule(primals, tangents):
57
57
dS = Ut @ dA @ V
58
58
ds = jnp .real (jnp .diagonal (dS , 0 , - 2 , - 1 ))
59
59
60
- s_diffs = (s_dim + _T (s_dim )) * (s_dim - _T (s_dim ))
61
- # s_diffs = jnp.where(s_diffs / (s[0] ** 2) >= 1e-12, s_diffs, 0)
60
+ s_sums = s_dim + _T (s_dim )
61
+ s_diffs = s_dim - _T (s_dim )
62
+ s_diffs = jnp .where (jnp .abs (s_diffs / s [0 ]) >= 1e-12 , s_diffs , 0 )
62
63
s_diffs_zeros = jnp .ones ((), dtype = A .dtype ) * (
63
64
s_diffs == 0.0
64
65
) # is 1. where s_diffs is 0. and is 0. everywhere else
65
66
s_diffs_zeros = lax .expand_dims (s_diffs_zeros , range (s_diffs .ndim - 2 ))
66
67
F = 1 / (s_diffs + s_diffs_zeros ) - s_diffs_zeros
67
- dSS = s_dim .astype (A .dtype ) * dS # dS.dot(jnp.diag(s))
68
- SdS = _T (s_dim .astype (A .dtype )) * dS # jnp.diag(s).dot(dS)
68
+ dSS = dS * (s_dim / s_sums ).astype (A .dtype ) # dS.dot(s_j / (s_i + s_j))
69
+ SdS = (_T (s_dim ) / s_sums ).astype (A .dtype ) * dS # (s_i / (s_i + s_j)).dot(dS)
70
+
71
+ # s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
72
+ # # s_diffs = jnp.where(s_diffs / (s[0] ** 2) >= 1e-12, s_diffs, 0)
73
+ # s_diffs_zeros = jnp.ones((), dtype=A.dtype) * (
74
+ # s_diffs == 0.0
75
+ # ) # is 1. where s_diffs is 0. and is 0. everywhere else
76
+ # s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
77
+ # F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
78
+ # dSS = s_dim.astype(A.dtype) * dS # dS.dot(jnp.diag(s))
79
+ # SdS = _T(s_dim.astype(A.dtype)) * dS # jnp.diag(s).dot(dS)
69
80
70
81
s_zeros = (s == 0 ).astype (s .dtype )
71
82
s_inv = 1 / (s + s_zeros ) - s_zeros
72
83
s_inv_mat = jnp .vectorize (jnp .diag , signature = "(k)->(k,k)" )(s_inv )
73
84
dUdV_diag = 0.5 * (dS - _H (dS )) * s_inv_mat .astype (A .dtype )
74
- dU = U @ (F .astype (A .dtype ) * (dSS + _H (dSS )) + dUdV_diag )
75
- dV = V @ (F .astype (A .dtype ) * (SdS + _H (SdS )))
85
+ dU = U @ (F .astype (A .dtype ) * (dSS + _H (dSS )) + 0.5 * dUdV_diag )
86
+ dV = V @ (F .astype (A .dtype ) * (SdS + _H (SdS )) + 0.5 * dUdV_diag )
76
87
77
88
m , n = A .shape [- 2 :]
78
89
if m > n :
0 commit comments