Skip to content

Commit 3d14748

Browse files
GH1255 Improve index typing for Series (#1261)
* GH1255 Initial commit * GH1255 Initial commit * GH1255 Initial commit * GH1255 Initial commit * GH1255 Initial commit
1 parent 494b3b0 commit 3d14748

File tree

3 files changed

+41
-20
lines changed

3 files changed

+41
-20
lines changed

pandas-stubs/_typing.pyi

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ from collections.abc import (
33
Callable,
44
Hashable,
55
Iterator,
6+
KeysView,
67
Mapping,
78
MutableSequence,
89
Sequence,
@@ -796,9 +797,6 @@ SliceType: TypeAlias = Hashable | None
796797

797798
num: TypeAlias = complex
798799

799-
# AxesData is used for data for Index
800-
AxesData: TypeAlias = Axes | dict
801-
802800
DtypeNp = TypeVar("DtypeNp", bound=np.dtype[np.generic])
803801
KeysArgType: TypeAlias = Any
804802
ListLikeT = TypeVar("ListLikeT", bound=ListLike)
@@ -853,6 +851,9 @@ IndexingInt: TypeAlias = (
853851
int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8
854852
)
855853

854+
# AxesData is used for data for Index
855+
AxesData: TypeAlias = Mapping[S3, Any] | Axes | KeysView
856+
856857
# Any plain Python or numpy function
857858
Function: TypeAlias = np.ufunc | Callable[..., Any]
858859
# Use a distinct HashableT in shared types to avoid conflicts with

pandas-stubs/core/series.pyi

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ from pandas._typing import (
111111
AnyArrayLike,
112112
ArrayLike,
113113
Axes,
114+
AxesData,
114115
Axis,
115116
AxisColumn,
116117
AxisIndex,
@@ -252,7 +253,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
252253
def __new__(
253254
cls,
254255
data: npt.NDArray[np.float64],
255-
index: Axes | None = ...,
256+
index: AxesData | None = ...,
256257
dtype: Dtype = ...,
257258
name: Hashable = ...,
258259
copy: bool = ...,
@@ -261,7 +262,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
261262
def __new__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
262263
cls,
263264
data: Sequence[Never],
264-
index: Axes | None = ...,
265+
index: AxesData | None = ...,
265266
dtype: Dtype = ...,
266267
name: Hashable = ...,
267268
copy: bool = ...,
@@ -270,7 +271,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
270271
def __new__(
271272
cls,
272273
data: Sequence[list[_str]],
273-
index: Axes | None = ...,
274+
index: AxesData | None = ...,
274275
dtype: Dtype = ...,
275276
name: Hashable = ...,
276277
copy: bool = ...,
@@ -279,7 +280,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
279280
def __new__(
280281
cls,
281282
data: Sequence[_str],
282-
index: Axes | None = ...,
283+
index: AxesData | None = ...,
283284
dtype: Dtype = ...,
284285
name: Hashable = ...,
285286
copy: bool = ...,
@@ -295,7 +296,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
295296
| datetime
296297
| date
297298
),
298-
index: Axes | None = ...,
299+
index: AxesData | None = ...,
299300
dtype: TimestampDtypeArg = ...,
300301
name: Hashable = ...,
301302
copy: bool = ...,
@@ -304,7 +305,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
304305
def __new__(
305306
cls,
306307
data: _ListLike,
307-
index: Axes | None = ...,
308+
index: AxesData | None = ...,
308309
*,
309310
dtype: TimestampDtypeArg,
310311
name: Hashable = ...,
@@ -314,7 +315,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
314315
def __new__(
315316
cls,
316317
data: PeriodIndex | Sequence[Period],
317-
index: Axes | None = ...,
318+
index: AxesData | None = ...,
318319
dtype: PeriodDtype = ...,
319320
name: Hashable = ...,
320321
copy: bool = ...,
@@ -329,7 +330,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
329330
| np.timedelta64
330331
| timedelta
331332
),
332-
index: Axes | None = ...,
333+
index: AxesData | None = ...,
333334
dtype: TimedeltaDtypeArg = ...,
334335
name: Hashable = ...,
335336
copy: bool = ...,
@@ -343,7 +344,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
343344
| Sequence[Interval[_OrderableT]]
344345
| dict[HashableT1, Interval[_OrderableT]]
345346
),
346-
index: Axes | None = ...,
347+
index: AxesData | None = ...,
347348
dtype: Literal["Interval"] = ...,
348349
name: Hashable = ...,
349350
copy: bool = ...,
@@ -352,7 +353,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
352353
def __new__( # type: ignore[overload-overlap]
353354
cls,
354355
data: Scalar | _ListLike | dict[HashableT1, Any] | None,
355-
index: Axes | None = ...,
356+
index: AxesData | None = ...,
356357
*,
357358
dtype: type[S1],
358359
name: Hashable = ...,
@@ -362,7 +363,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
362363
def __new__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
363364
cls,
364365
data: Sequence[bool],
365-
index: Axes | None = ...,
366+
index: AxesData | None = ...,
366367
dtype: Dtype = ...,
367368
name: Hashable = ...,
368369
copy: bool = ...,
@@ -371,7 +372,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
371372
def __new__( # type: ignore[overload-overlap]
372373
cls,
373374
data: Sequence[int],
374-
index: Axes | None = ...,
375+
index: AxesData | None = ...,
375376
dtype: Dtype = ...,
376377
name: Hashable = ...,
377378
copy: bool = ...,
@@ -380,7 +381,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
380381
def __new__(
381382
cls,
382383
data: Sequence[float],
383-
index: Axes | None = ...,
384+
index: AxesData | None = ...,
384385
dtype: Dtype = ...,
385386
name: Hashable = ...,
386387
copy: bool = ...,
@@ -389,7 +390,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
389390
def __new__( # type: ignore[overload-cannot-match] # pyright: ignore[reportOverlappingOverload]
390391
cls,
391392
data: Sequence[int | float],
392-
index: Axes | None = ...,
393+
index: AxesData | None = ...,
393394
dtype: Dtype = ...,
394395
name: Hashable = ...,
395396
copy: bool = ...,
@@ -398,7 +399,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
398399
def __new__(
399400
cls,
400401
data: S1 | _ListLike[S1] | dict[HashableT1, S1] | dict_keys[S1, Any],
401-
index: Axes | None = ...,
402+
index: AxesData | None = ...,
402403
dtype: Dtype = ...,
403404
name: Hashable = ...,
404405
copy: bool = ...,
@@ -415,7 +416,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
415416
| NAType
416417
| None
417418
) = ...,
418-
index: Axes | None = ...,
419+
index: AxesData | None = ...,
419420
dtype: Dtype = ...,
420421
name: Hashable = ...,
421422
copy: bool = ...,

tests/test_series.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,7 @@ def test_types_scalar_arithmetic() -> None:
868868

869869

870870
def test_types_complex_arithmetic() -> None:
871-
# GH 103
871+
"""Test adding complex number to pd.Series[float] GH 103."""
872872
c = 1 + 1j
873873
s = pd.Series([1.0, 2.0, 3.0])
874874
x = s + c
@@ -3922,3 +3922,22 @@ def test_series_unstack() -> None:
39223922
),
39233923
pd.DataFrame,
39243924
)
3925+
3926+
3927+
def test_series_index_type() -> None:
3928+
index = {"a": 3, "c": 4}
3929+
lst = [1, 2]
3930+
3931+
check(
3932+
assert_type(pd.Series(lst, index=index), "pd.Series[int]"),
3933+
pd.Series,
3934+
np.integer,
3935+
)
3936+
check(
3937+
assert_type(pd.Series([1, 2], index=index.keys()), "pd.Series[int]"),
3938+
pd.Series,
3939+
np.integer,
3940+
)
3941+
3942+
if TYPE_CHECKING_INVALID_USAGE:
3943+
t = pd.Series([1, 2], index="ab") # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportArgumentType]

0 commit comments

Comments
 (0)