diff --git a/pandas/io/sql.py b/pandas/io/sql.py index d4aab05b22adf..aa4bcd8b1565a 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -7,7 +7,7 @@ from datetime import date, datetime, time from functools import partial import re -from typing import Any, Dict, Iterator, List, Optional, Sequence, Union, overload +from typing import Any, Dict, Iterator, List, Optional, Sequence, Union, cast, overload import warnings import numpy as np @@ -383,6 +383,8 @@ def read_sql_query( Data type for data or columns. E.g. np.float64 or {‘a’: np.float64, ‘b’: np.int32, ‘c’: ‘Int64’} + .. versionadded:: 1.3.0 + Returns ------- DataFrame or Iterator[DataFrame] @@ -609,7 +611,7 @@ def to_sql( index: bool = True, index_label=None, chunksize: Optional[int] = None, - dtype=None, + dtype: Optional[DtypeArg] = None, method: Optional[str] = None, ) -> None: """ @@ -768,7 +770,7 @@ def __init__( index_label=None, schema=None, keys=None, - dtype=None, + dtype: Optional[DtypeArg] = None, ): self.name = name self.pd_sql = pandas_sql_engine @@ -1108,9 +1110,11 @@ def _harmonize_columns(self, parse_dates=None): def _sqlalchemy_type(self, col): - dtype = self.dtype or {} - if col.name in dtype: - return self.dtype[col.name] + dtype: DtypeArg = self.dtype or {} + if is_dict_like(dtype): + dtype = cast(dict, dtype) + if col.name in dtype: + return dtype[col.name] # Infer type of column, while ignoring missing values. # Needed for inserting typed data containing NULLs, GH 8778. @@ -1209,7 +1213,18 @@ def read_sql(self, *args, **kwargs): "connectable or sqlite connection" ) - def to_sql(self, *args, **kwargs): + def to_sql( + self, + frame, + name, + if_exists="fail", + index=True, + index_label=None, + schema=None, + chunksize=None, + dtype: Optional[DtypeArg] = None, + method=None, + ): raise ValueError( "PandasSQL must be created with an SQLAlchemy " "connectable or sqlite connection" @@ -1436,7 +1451,7 @@ def to_sql( index_label=None, schema=None, chunksize=None, - dtype=None, + dtype: Optional[DtypeArg] = None, method=None, ): """ @@ -1480,10 +1495,12 @@ def to_sql( .. versionadded:: 0.24.0 """ - if dtype and not is_dict_like(dtype): - dtype = {col_name: dtype for col_name in frame} + if dtype: + if not is_dict_like(dtype): + dtype = {col_name: dtype for col_name in frame} + else: + dtype = cast(dict, dtype) - if dtype is not None: from sqlalchemy.types import TypeEngine, to_instance for col, my_type in dtype.items(): @@ -1569,7 +1586,7 @@ def _create_sql_schema( frame: DataFrame, table_name: str, keys: Optional[List[str]] = None, - dtype: Optional[dict] = None, + dtype: Optional[DtypeArg] = None, schema: Optional[str] = None, ): table = SQLTable( @@ -1740,9 +1757,11 @@ def _create_table_setup(self): return create_stmts def _sql_type_name(self, col): - dtype = self.dtype or {} - if col.name in dtype: - return dtype[col.name] + dtype: DtypeArg = self.dtype or {} + if is_dict_like(dtype): + dtype = cast(dict, dtype) + if col.name in dtype: + return dtype[col.name] # Infer type of column, while ignoring missing values. # Needed for inserting typed data containing NULLs, GH 8778. @@ -1901,7 +1920,7 @@ def to_sql( index_label=None, schema=None, chunksize=None, - dtype=None, + dtype: Optional[DtypeArg] = None, method=None, ): """ @@ -1944,10 +1963,12 @@ def to_sql( .. versionadded:: 0.24.0 """ - if dtype and not is_dict_like(dtype): - dtype = {col_name: dtype for col_name in frame} + if dtype: + if not is_dict_like(dtype): + dtype = {col_name: dtype for col_name in frame} + else: + dtype = cast(dict, dtype) - if dtype is not None: for col, my_type in dtype.items(): if not isinstance(my_type, str): raise ValueError(f"{col} ({my_type}) not a string") @@ -1986,7 +2007,7 @@ def _create_sql_schema( frame, table_name: str, keys=None, - dtype=None, + dtype: Optional[DtypeArg] = None, schema: Optional[str] = None, ): table = SQLiteTable( @@ -2002,7 +2023,12 @@ def _create_sql_schema( def get_schema( - frame, name: str, keys=None, con=None, dtype=None, schema: Optional[str] = None + frame, + name: str, + keys=None, + con=None, + dtype: Optional[DtypeArg] = None, + schema: Optional[str] = None, ): """ Get the SQL db table schema for the given frame.