Skip to content

REF: share comparison methods for DTA/TDA/PA #30751

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 81 additions & 2 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pandas._libs.tslibs.timedeltas import Timedelta, delta_to_nanoseconds
from pandas._libs.tslibs.timestamps import RoundTo, round_nsint64
from pandas._typing import DatetimeLikeScalar
from pandas.compat import set_function_name
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError, NullFrequencyError, PerformanceWarning
from pandas.util._decorators import Appender, Substitution
Expand All @@ -37,19 +38,94 @@
from pandas.core.dtypes.inference import is_array_like
from pandas.core.dtypes.missing import is_valid_nat_for_dtype, isna

from pandas.core import missing, nanops
from pandas.core import missing, nanops, ops
from pandas.core.algorithms import checked_add_with_arr, take, unique1d, value_counts
import pandas.core.common as com
from pandas.core.indexers import check_bool_array_indexer
from pandas.core.ops.common import unpack_zerodim_and_defer
from pandas.core.ops.invalid import make_invalid_op
from pandas.core.ops.invalid import invalid_comparison, make_invalid_op

from pandas.tseries import frequencies
from pandas.tseries.offsets import DateOffset, Tick

from .base import ExtensionArray, ExtensionOpsMixin


def _datetimelike_array_cmp(cls, op):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't an exact copy of `_dt_array_cmp`, right? If not, can you help point out where it integrates pieces from the period / td methods?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I put some comments in when moving this; are there specific parts that aren't clear?

"""
Wrap comparison operations to convert Timestamp/Timedelta/Period-like to
boxed scalars/arrays.
"""
opname = f"__{op.__name__}__"
nat_result = opname == "__ne__"

@unpack_zerodim_and_defer(opname)
def wrapper(self, other):

if isinstance(other, str):
try:
# GH#18435 strings get a pass from tzawareness compat
other = self._scalar_from_string(other)
except ValueError:
# failed to parse as Timestamp/Timedelta/Period
return invalid_comparison(self, other, op)

if isinstance(other, self._recognized_scalars) or other is NaT:
other = self._scalar_type(other)
self._check_compatible_with(other)

other_i8 = self._unbox_scalar(other)

result = op(self.view("i8"), other_i8)
if isna(other):
result.fill(nat_result)

elif not is_list_like(other):
return invalid_comparison(self, other, op)

elif len(other) != len(self):
raise ValueError("Lengths must match")

else:
if isinstance(other, list):
# TODO: could use pd.Index to do inference?
other = np.array(other)

if not isinstance(other, (np.ndarray, type(self))):
return invalid_comparison(self, other, op)

if is_object_dtype(other):
# We have to use comp_method_OBJECT_ARRAY instead of numpy
# comparison otherwise it would fail to raise when
# comparing tz-aware and tz-naive
with np.errstate(all="ignore"):
result = ops.comp_method_OBJECT_ARRAY(
op, self.astype(object), other
)
o_mask = isna(other)

elif not type(self)._is_recognized_dtype(other.dtype):
return invalid_comparison(self, other, op)

else:
# For PeriodDType this casting is unnecessary
other = type(self)._from_sequence(other)
self._check_compatible_with(other)

result = op(self.view("i8"), other.view("i8"))
o_mask = other._isnan

if o_mask.any():
result[o_mask] = nat_result

if self._hasnans:
result[self._isnan] = nat_result

return result

return set_function_name(wrapper, opname, cls)


class AttributesMixin:
_data: np.ndarray

Expand Down Expand Up @@ -934,6 +1010,7 @@ def _is_unique(self):

# ------------------------------------------------------------------
# Arithmetic Methods
_create_comparison_method = classmethod(_datetimelike_array_cmp)

# pow is invalid for all three subclasses; TimedeltaArray will override
# the multiplication and division ops
Expand Down Expand Up @@ -1485,6 +1562,8 @@ def mean(self, skipna=True):
return self._box_func(result)


DatetimeLikeArrayMixin._add_comparison_ops()

# -------------------------------------------------------------------
# Shared Constructor Helpers

Expand Down
87 changes: 1 addition & 86 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
timezones,
tzconversion,
)
import pandas.compat as compat
from pandas.errors import PerformanceWarning

from pandas.core.dtypes.common import (
Expand All @@ -32,7 +31,6 @@
is_dtype_equal,
is_extension_array_dtype,
is_float_dtype,
is_list_like,
is_object_dtype,
is_period_dtype,
is_string_dtype,
Expand All @@ -43,13 +41,10 @@
from pandas.core.dtypes.generic import ABCIndexClass, ABCPandasArray, ABCSeries
from pandas.core.dtypes.missing import isna

from pandas.core import ops
from pandas.core.algorithms import checked_add_with_arr
from pandas.core.arrays import datetimelike as dtl
from pandas.core.arrays._ranges import generate_regular_range
import pandas.core.common as com
from pandas.core.ops.common import unpack_zerodim_and_defer
from pandas.core.ops.invalid import invalid_comparison

from pandas.tseries.frequencies import get_period_alias, to_offset
from pandas.tseries.offsets import Day, Tick
Expand Down Expand Up @@ -131,81 +126,6 @@ def f(self):
return property(f)


def _dt_array_cmp(cls, op):
    """
    Build a comparison method for ``cls`` that converts datetime-like
    operands to datetime64 before comparing.

    Parameters
    ----------
    cls : type
        Class the generated method is attached to; also used for the
        ``isinstance`` and dtype-recognition checks on array operands.
    op : callable
        Comparison operator; its ``__name__`` determines the dunder name.

    Returns
    -------
    callable
        The comparison method, renamed via ``compat.set_function_name``.
    """
    method_name = f"__{op.__name__}__"
    # Positions where either side is NaT compare as True only for "!=".
    nat_fill = method_name == "__ne__"

    @unpack_zerodim_and_defer(method_name)
    def compare(self, other):
        if isinstance(other, str):
            try:
                # GH#18435 strings get a pass from tzawareness compat
                other = self._scalar_from_string(other)
            except ValueError:
                # string that cannot be parsed to Timestamp
                return invalid_comparison(self, other, op)

        if isinstance(other, self._recognized_scalars) or other is NaT:
            # Scalar operand: box it, validate tz-awareness, compare on i8.
            other = self._scalar_type(other)
            self._assert_tzawareness_compat(other)

            scalar_i8 = other.value

            result = op(self.view("i8"), scalar_i8)
            if isna(other):
                result.fill(nat_fill)

        else:
            if not is_list_like(other):
                return invalid_comparison(self, other, op)

            if len(other) != len(self):
                raise ValueError("Lengths must match")

            if isinstance(other, list):
                other = np.array(other)

            if not isinstance(other, (np.ndarray, cls)):
                # Following Timestamp convention, __eq__ is all-False
                # and __ne__ is all True, others raise TypeError.
                return invalid_comparison(self, other, op)

            if is_object_dtype(other):
                # We have to use comp_method_OBJECT_ARRAY instead of numpy
                # comparison otherwise it would fail to raise when
                # comparing tz-aware and tz-naive
                with np.errstate(all="ignore"):
                    result = ops.comp_method_OBJECT_ARRAY(
                        op, self.astype(object), other
                    )
                other_na = isna(other)

            elif cls._is_recognized_dtype(other.dtype):
                self._assert_tzawareness_compat(other)
                other = type(self)._from_sequence(other)

                result = op(self.view("i8"), other.view("i8"))
                other_na = other._isnan

            else:
                # e.g. is_timedelta64_dtype(other)
                return invalid_comparison(self, other, op)

            if other_na.any():
                result[other_na] = nat_fill

        # NaT entries on the left-hand side get the same fill value.
        if self._hasnans:
            result[self._isnan] = nat_fill

        return result

    return compat.set_function_name(compare, method_name, cls)


class DatetimeArray(dtl.DatetimeLikeArrayMixin, dtl.TimelikeOps, dtl.DatelikeOps):
"""
Pandas ExtensionArray for tz-naive or tz-aware datetime data.
Expand Down Expand Up @@ -324,7 +244,7 @@ def __init__(self, values, dtype=_NS_DTYPE, freq=None, copy=False):
raise TypeError(msg)
elif values.tz:
dtype = values.dtype
# freq = validate_values_freq(values, freq)

if freq is None:
freq = values.freq
values = values._data
Expand Down Expand Up @@ -714,8 +634,6 @@ def _format_native_types(self, na_rep="NaT", date_format=None, **kwargs):
# -----------------------------------------------------------------
# Comparison Methods

_create_comparison_method = classmethod(_dt_array_cmp)

def _has_same_tz(self, other):
zzone = self._timezone

Expand Down Expand Up @@ -1767,9 +1685,6 @@ def to_julian_date(self):
)


DatetimeArray._add_comparison_ops()


# -------------------------------------------------------------------
# Constructor Helpers

Expand Down
81 changes: 0 additions & 81 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@
period_asfreq_arr,
)
from pandas._libs.tslibs.timedeltas import Timedelta, delta_to_nanoseconds
import pandas.compat as compat
from pandas.util._decorators import cache_readonly

from pandas.core.dtypes.common import (
_TD_DTYPE,
ensure_object,
is_datetime64_dtype,
is_float_dtype,
is_list_like,
is_object_dtype,
is_period_dtype,
pandas_dtype,
)
Expand All @@ -42,12 +39,9 @@
)
from pandas.core.dtypes.missing import isna, notna

from pandas.core import ops
import pandas.core.algorithms as algos
from pandas.core.arrays import datetimelike as dtl
import pandas.core.common as com
from pandas.core.ops.common import unpack_zerodim_and_defer
from pandas.core.ops.invalid import invalid_comparison

from pandas.tseries import frequencies
from pandas.tseries.offsets import DateOffset, Tick, _delta_to_tick
Expand All @@ -64,77 +58,6 @@ def f(self):
return property(f)


def _period_array_cmp(cls, op):
    """
    Build a comparison method for ``cls`` that converts Period-like
    operands to PeriodDtype before comparing.

    Parameters
    ----------
    cls : type
        Class the generated method is attached to; also used for the
        ``isinstance`` and dtype-recognition checks on array operands.
    op : callable
        Comparison operator; its ``__name__`` determines the dunder name.

    Returns
    -------
    callable
        The comparison method, renamed via ``compat.set_function_name``.
    """
    method_name = f"__{op.__name__}__"
    # Positions where either side is NaT compare as True only for "!=".
    nat_fill = method_name == "__ne__"

    @unpack_zerodim_and_defer(method_name)
    def compare(self, other):
        if isinstance(other, str):
            try:
                other = self._scalar_from_string(other)
            except ValueError:
                # string that can't be parsed as Period
                return invalid_comparison(self, other, op)

        if isinstance(other, self._recognized_scalars) or other is NaT:
            # Scalar operand: box it, validate compatibility, compare on i8.
            other = self._scalar_type(other)
            self._check_compatible_with(other)

            scalar_i8 = self._unbox_scalar(other)

            result = op(self.view("i8"), scalar_i8)
            if isna(other):
                result.fill(nat_fill)

        else:
            if not is_list_like(other):
                return invalid_comparison(self, other, op)

            if len(other) != len(self):
                raise ValueError("Lengths must match")

            if isinstance(other, list):
                # TODO: could use pd.Index to do inference?
                other = np.array(other)

            if not isinstance(other, (np.ndarray, cls)):
                return invalid_comparison(self, other, op)

            if is_object_dtype(other):
                # Object arrays are compared through the object-array helper.
                with np.errstate(all="ignore"):
                    result = ops.comp_method_OBJECT_ARRAY(
                        op, self.astype(object), other
                    )
                other_na = isna(other)

            elif cls._is_recognized_dtype(other.dtype):
                assert isinstance(other, cls), type(other)

                self._check_compatible_with(other)

                result = op(self.view("i8"), other.view("i8"))
                other_na = other._isnan

            else:
                # e.g. is_timedelta64_dtype(other)
                return invalid_comparison(self, other, op)

            if other_na.any():
                result[other_na] = nat_fill

        # NaT entries on the left-hand side get the same fill value.
        if self._hasnans:
            result[self._isnan] = nat_fill

        return result

    return compat.set_function_name(compare, method_name, cls)


class PeriodArray(dtl.DatetimeLikeArrayMixin, dtl.DatelikeOps):
"""
Pandas ExtensionArray for storing Period data.
Expand Down Expand Up @@ -639,7 +562,6 @@ def astype(self, dtype, copy=True):

# ------------------------------------------------------------------
# Arithmetic Methods
_create_comparison_method = classmethod(_period_array_cmp)

def _sub_datelike(self, other):
assert other is not NaT
Expand Down Expand Up @@ -810,9 +732,6 @@ def _check_timedeltalike_freq_compat(self, other):
raise raise_on_incompatible(self, other)


PeriodArray._add_comparison_ops()


def raise_on_incompatible(left, right):
"""
Helper function to render a consistent error message when raising
Expand Down
Loading