Skip to content

Commit cfb0484

Browse files
committed
Ruff autofixes for type hints with 3.11+ features
1 parent 252a8b1 commit cfb0484

File tree

15 files changed

+60
-65
lines changed

15 files changed

+60
-65
lines changed

benchmarks/benchmarking.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,10 @@ def _format_results_entry(results_entry: dict) -> str:
252252

253253
def _dict_product(dicts: dict[str, Iterable[Any]]) -> Iterable[dict[str, Any]]:
254254
"""Generator corresponding to Cartesian product of dictionaries."""
255-
return (dict(zip(dicts.keys(), values)) for values in product(*dicts.values()))
255+
return (
256+
dict(zip(dicts.keys(), values, strict=False))
257+
for values in product(*dicts.values())
258+
)
256259

257260

258261
def _parse_value(value: str) -> Any:

benchmarks/plotting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,10 @@ def plot_results_against_bandlimit(
141141
squeeze=False,
142142
)
143143
axes = axes.T if functions_along_columns else axes
144-
for axes_row, function in zip(axes, functions):
144+
for axes_row, function in zip(axes, functions, strict=False):
145145
results = benchmark_results["results"][function]
146146
l_values = np.array([r["parameters"]["L"] for r in results])
147-
for ax, measurement in zip(axes_row, measurements):
147+
for ax, measurement in zip(axes_row, measurements, strict=False):
148148
plot_function, label = _measurement_plot_functions_and_labels[measurement]
149149
try:
150150
plot_function(ax, "L", l_values, results)

s2fft/precompute_transforms/construct.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from typing import Tuple
21
from warnings import warn
32

43
import jax
@@ -612,7 +611,7 @@ def wigner_kernel_jax(
612611
wigner_kernel_torch = torch_wrapper.wrap_as_torch_function(wigner_kernel_jax)
613612

614613

615-
def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]:
614+
def fourier_wigner_kernel(L: int) -> tuple[np.ndarray, np.ndarray]:
616615
"""
617616
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
618617
weights upsampled for the forward Fourier-Wigner transform.
@@ -640,7 +639,7 @@ def fourier_wigner_kernel(L: int) -> Tuple[np.ndarray, np.ndarray]:
640639
return deltas, w
641640

642641

643-
def fourier_wigner_kernel_jax(L: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
642+
def fourier_wigner_kernel_jax(L: int) -> tuple[jnp.ndarray, jnp.ndarray]:
644643
"""
645644
Computes Fourier coefficients of the reduced Wigner d-functions and quadrature
646645
weights upsampled for the forward Fourier-Wigner transform (JAX implementation).

s2fft/precompute_transforms/custom_ops.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from functools import partial
2-
from typing import Tuple
32

43
import jax.numpy as jnp
54
import numpy as np
@@ -9,7 +8,7 @@
98
def wigner_subset_to_s2(
109
flmn: np.ndarray,
1110
spins: np.ndarray,
12-
DW: Tuple[np.ndarray, np.ndarray],
11+
DW: tuple[np.ndarray, np.ndarray],
1312
L: int,
1413
sampling: str = "mw",
1514
) -> np.ndarray:
@@ -91,7 +90,7 @@ def wigner_subset_to_s2(
9190
def wigner_subset_to_s2_jax(
9291
flmn: jnp.ndarray,
9392
spins: jnp.ndarray,
94-
DW: Tuple[jnp.ndarray, jnp.ndarray],
93+
DW: tuple[jnp.ndarray, jnp.ndarray],
9594
L: int,
9695
sampling: str = "mw",
9796
) -> jnp.ndarray:
@@ -173,7 +172,7 @@ def wigner_subset_to_s2_jax(
173172
def so3_to_wigner_subset(
174173
f: np.ndarray,
175174
spins: np.ndarray,
176-
DW: Tuple[np.ndarray, np.ndarray],
175+
DW: tuple[np.ndarray, np.ndarray],
177176
L: int,
178177
N: int,
179178
sampling: str = "mw",
@@ -214,7 +213,7 @@ def so3_to_wigner_subset(
214213
def so3_to_wigner_subset_jax(
215214
f: jnp.ndarray,
216215
spins: jnp.ndarray,
217-
DW: Tuple[jnp.ndarray, jnp.ndarray],
216+
DW: tuple[jnp.ndarray, jnp.ndarray],
218217
L: int,
219218
N: int,
220219
sampling: str = "mw",
@@ -257,7 +256,7 @@ def so3_to_wigner_subset_jax(
257256
def s2_to_wigner_subset(
258257
fs: np.ndarray,
259258
spins: np.ndarray,
260-
DW: Tuple[np.ndarray, np.ndarray],
259+
DW: tuple[np.ndarray, np.ndarray],
261260
L: int,
262261
sampling: str = "mw",
263262
) -> np.ndarray:
@@ -343,7 +342,7 @@ def s2_to_wigner_subset(
343342
def s2_to_wigner_subset_jax(
344343
fs: jnp.ndarray,
345344
spins: jnp.ndarray,
346-
DW: Tuple[jnp.ndarray, jnp.ndarray],
345+
DW: tuple[jnp.ndarray, jnp.ndarray],
347346
L: int,
348347
sampling: str = "mw",
349348
) -> jnp.ndarray:

s2fft/precompute_transforms/spherical.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from functools import partial
2-
from typing import Optional
32
from warnings import warn
43

54
import jax.numpy as jnp
@@ -21,11 +20,11 @@ def inverse(
2120
flm: np.ndarray,
2221
L: int,
2322
spin: int = 0,
24-
kernel: Optional[np.ndarray] = None,
23+
kernel: np.ndarray | None = None,
2524
sampling: str = "mw",
2625
reality: bool = False,
2726
method: str = "jax",
28-
nside: Optional[int] = None,
27+
nside: int | None = None,
2928
) -> np.ndarray:
3029
r"""
3130
Compute the inverse spherical harmonic transform via precompute.
@@ -228,11 +227,11 @@ def forward(
228227
f: np.ndarray,
229228
L: int,
230229
spin: int = 0,
231-
kernel: Optional[np.ndarray] = None,
230+
kernel: np.ndarray | None = None,
232231
sampling: str = "mw",
233232
reality: bool = False,
234233
method: str = "jax",
235-
nside: Optional[int] = None,
234+
nside: int | None = None,
236235
iter: int = 0,
237236
) -> np.ndarray:
238237
r"""

s2fft/recursions/price_mcewen.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import warnings
22
from functools import partial
3-
from typing import List
43

54
import jax.lax as lax
65
import jax.numpy as jnp
@@ -19,7 +18,7 @@ def generate_precomputes(
1918
nside: int = None,
2019
forward: bool = False,
2120
L_lower: int = 0,
22-
) -> List[np.ndarray]:
21+
) -> list[np.ndarray]:
2322
r"""
2423
Compute recursion coefficients with :math:`\mathcal{O}(L^3)` memory overhead.
2524
@@ -125,7 +124,7 @@ def generate_precomputes_jax(
125124
forward: bool = False,
126125
L_lower: int = 0,
127126
betas: jnp.ndarray = None,
128-
) -> List[jnp.ndarray]:
127+
) -> list[jnp.ndarray]:
129128
r"""
130129
Compute recursion coefficients with :math:`\mathcal{O}(L^2)` memory overhead.
131130
In practice one could compute these on-the-fly but the memory overhead is
@@ -264,7 +263,7 @@ def generate_precomputes_wigner(
264263
forward: bool = False,
265264
reality: bool = False,
266265
L_lower: int = 0,
267-
) -> List[List[np.ndarray]]:
266+
) -> list[list[np.ndarray]]:
268267
r"""
269268
Compute recursion coefficients with :math:`\mathcal{O}(L^2)` memory overhead.
270269
In practice one could compute these on-the-fly but the memory overhead is
@@ -316,7 +315,7 @@ def generate_precomputes_wigner_jax(
316315
forward: bool = False,
317316
reality: bool = False,
318317
L_lower: int = 0,
319-
) -> List[List[jnp.ndarray]]:
318+
) -> list[list[jnp.ndarray]]:
320319
r"""
321320
Compute recursion coefficients with :math:`\mathcal{O}(L^2)` memory overhead.
322321
In practice one could compute these on-the-fly but the memory overhead is

s2fft/sampling/s2_samples.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Tuple
2-
31
import numpy as np
42

53

@@ -125,7 +123,7 @@ def nphi_equiang(L: int, sampling: str = "mw") -> int:
125123
return 1
126124

127125

128-
def ftm_shape(L: int, sampling: str = "mw", nside: int = None) -> Tuple[int, int]:
126+
def ftm_shape(L: int, sampling: str = "mw", nside: int = None) -> tuple[int, int]:
129127
r"""
130128
Shape of intermediate array, before/after latitudinal step.
131129
@@ -445,7 +443,7 @@ def ring_phase_shift_hp(
445443
return np.exp(sign * 1j * np.arange(m_start_ind, L) * phi_offset)
446444

447445

448-
def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> Tuple[int]:
446+
def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> tuple[int]:
449447
r"""
450448
Shape of spherical signal.
451449
@@ -480,7 +478,7 @@ def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> Tuple[int
480478
return ntheta(L, sampling), nphi_equiang(L, sampling)
481479

482480

483-
def flm_shape(L: int) -> Tuple[int, int]:
481+
def flm_shape(L: int) -> tuple[int, int]:
484482
r"""
485483
Standard shape of harmonic coefficients.
486484

s2fft/sampling/so3_samples.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
from typing import Tuple
2-
31
import numpy as np
42

53
from s2fft.sampling import s2_samples as samples
64

75

86
def f_shape(
97
L: int, N: int, sampling: str = "mw", nside: int = None
10-
) -> Tuple[int, int, int]:
8+
) -> tuple[int, int, int]:
119
r"""
1210
Computes the pixel-space sampling shape for signal on the rotation group
1311
:math:`SO(3)`.
@@ -49,7 +47,7 @@ def f_shape(
4947
raise ValueError(f"Sampling scheme sampling={sampling} not supported")
5048

5149

52-
def flmn_shape(L: int, N: int) -> Tuple[int, int, int]:
50+
def flmn_shape(L: int, N: int) -> tuple[int, int, int]:
5351
r"""
5452
Computes the shape of Wigner coefficients for signal on the rotation group
5553
:math:`SO(3)`.
@@ -69,7 +67,7 @@ def flmn_shape(L: int, N: int) -> Tuple[int, int, int]:
6967

7068
def fnab_shape(
7169
L: int, N: int, sampling: str = "mw", nside: int = None
72-
) -> Tuple[int, int, int]:
70+
) -> tuple[int, int, int]:
7371
r"""
7472
Computes the shape of Wigner coefficients for signal on the rotation group
7573
:math:`SO(3)`.

s2fft/transforms/otf_recursions.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from functools import partial
2-
from typing import List
32

43
import jax.lax as lax
54
import jax.numpy as jnp
@@ -21,7 +20,7 @@ def inverse_latitudinal_step(
2120
nside: int,
2221
sampling: str = "mw",
2322
reality: bool = False,
24-
precomps: List = None,
23+
precomps: list = None,
2524
L_lower: int = 0,
2625
) -> np.ndarray:
2726
r"""
@@ -181,7 +180,7 @@ def inverse_latitudinal_step_jax(
181180
nside: int,
182181
sampling: str = "mw",
183182
reality: bool = False,
184-
precomps: List = None,
183+
precomps: list = None,
185184
spmd: bool = False,
186185
L_lower: int = 0,
187186
) -> jnp.ndarray:
@@ -438,7 +437,7 @@ def forward_latitudinal_step(
438437
nside: int,
439438
sampling: str = "mw",
440439
reality: bool = False,
441-
precomps: List = None,
440+
precomps: list = None,
442441
L_lower: int = 0,
443442
) -> np.ndarray:
444443
r"""
@@ -598,7 +597,7 @@ def forward_latitudinal_step_jax(
598597
nside: int,
599598
sampling: str = "mw",
600599
reality: bool = False,
601-
precomps: List = None,
600+
precomps: list = None,
602601
spmd: bool = False,
603602
L_lower: int = 0,
604603
) -> jnp.ndarray:

s2fft/transforms/spherical.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from functools import partial
2-
from typing import List, Optional
32

43
import jax.numpy as jnp
54
import numpy as np
@@ -27,7 +26,7 @@ def inverse(
2726
sampling: str = "mw",
2827
method: str = "numpy",
2928
reality: bool = False,
30-
precomps: List = None,
29+
precomps: list = None,
3130
spmd: bool = False,
3231
L_lower: int = 0,
3332
_ssht_backend: int = 1,
@@ -117,7 +116,7 @@ def inverse_numpy(
117116
nside: int = None,
118117
sampling: str = "mw",
119118
reality: bool = False,
120-
precomps: List = None,
119+
precomps: list = None,
121120
L_lower: int = 0,
122121
) -> np.ndarray:
123122
r"""
@@ -217,7 +216,7 @@ def inverse_jax(
217216
nside: int = None,
218217
sampling: str = "mw",
219218
reality: bool = False,
220-
precomps: List = None,
219+
precomps: list = None,
221220
spmd: bool = False,
222221
L_lower: int = 0,
223222
use_healpix_custom_primitive: bool = False,
@@ -354,14 +353,14 @@ def forward(
354353
f: np.ndarray,
355354
L: int,
356355
spin: int = 0,
357-
nside: Optional[int] = None,
356+
nside: int | None = None,
358357
sampling: str = "mw",
359358
method: str = "numpy",
360359
reality: bool = False,
361-
precomps: Optional[List] = None,
360+
precomps: list | None = None,
362361
spmd: bool = False,
363362
L_lower: int = 0,
364-
iter: Optional[int] = None,
363+
iter: int | None = None,
365364
_ssht_backend: int = 1,
366365
) -> np.ndarray:
367366
r"""
@@ -472,7 +471,7 @@ def forward_numpy(
472471
nside: int = None,
473472
sampling: str = "mw",
474473
reality: bool = False,
475-
precomps: List = None,
474+
precomps: list = None,
476475
L_lower: int = 0,
477476
) -> np.ndarray:
478477
r"""
@@ -597,7 +596,7 @@ def forward_jax(
597596
nside: int = None,
598597
sampling: str = "mw",
599598
reality: bool = False,
600-
precomps: List = None,
599+
precomps: list = None,
601600
spmd: bool = False,
602601
L_lower: int = 0,
603602
use_healpix_custom_primitive: bool = False,

0 commit comments

Comments
 (0)