diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index fb8eecdaa275e..326597763d1bc 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -585,7 +585,7 @@ Reshaping - Bug in :meth:`DataFrame.append` returning incorrect dtypes with combinations of ``datetime64`` and ``timedelta64`` dtypes (:issue:`39574`) - Bug in :meth:`DataFrame.pivot_table` returning a ``MultiIndex`` for a single value when operating on and empty ``DataFrame`` (:issue:`13483`) - Allow :class:`Index` to be passed to the :func:`numpy.all` function (:issue:`40180`) -- +- Bug in :meth:`DataFrame.stack` not preserving ``CategoricalDtype`` in a ``MultiIndex`` (:issue:`36991`) Sparse ^^^^^^ diff --git a/pandas/core/reshape/reshape.py b/pandas/core/reshape/reshape.py index 271bb2ca8dd75..ff6ba3f8f4164 100644 --- a/pandas/core/reshape/reshape.py +++ b/pandas/core/reshape/reshape.py @@ -600,6 +600,33 @@ def stack_multiple(frame, level, dropna=True): return result +def _stack_multi_column_index(columns: MultiIndex) -> MultiIndex: + """Creates a MultiIndex from the first N-1 levels of this MultiIndex.""" + if len(columns.levels) <= 2: + return columns.levels[0]._rename(name=columns.names[0]) + + levs = [ + [lev[c] if c >= 0 else None for c in codes] + for lev, codes in zip(columns.levels[:-1], columns.codes[:-1]) + ] + + # Remove duplicate tuples in the MultiIndex. + tuples = zip(*levs) + unique_tuples = (key for key, _ in itertools.groupby(tuples)) + new_levs = zip(*unique_tuples) + + # The dtype of each level must be explicitly set to avoid inferring the wrong type. + # See GH-36991. + return MultiIndex.from_arrays( + [ + # Not all indices can accept None values. 
+ Index(new_lev, dtype=lev.dtype) if None not in new_lev else new_lev + for new_lev, lev in zip(new_levs, columns.levels) + ], + names=columns.names[:-1], + ) + + def _stack_multi_columns(frame, level_num=-1, dropna=True): def _convert_level_number(level_num, columns): """ @@ -634,20 +661,7 @@ def _convert_level_number(level_num, columns): level_to_sort = _convert_level_number(0, this.columns) this = this.sort_index(level=level_to_sort, axis=1) - # tuple list excluding level for grouping columns - if len(frame.columns.levels) > 2: - levs = [] - for lev, level_codes in zip(this.columns.levels[:-1], this.columns.codes[:-1]): - if -1 in level_codes: - lev = np.append(lev, None) - levs.append(np.take(lev, level_codes)) - tuples = list(zip(*levs)) - unique_groups = [key for key, _ in itertools.groupby(tuples)] - new_names = this.columns.names[:-1] - new_columns = MultiIndex.from_tuples(unique_groups, names=new_names) - else: - new_columns = this.columns.levels[0]._rename(name=this.columns.names[0]) - unique_groups = new_columns + new_columns = _stack_multi_column_index(this.columns) # time to ravel the values new_data = {} @@ -658,7 +672,7 @@ def _convert_level_number(level_num, columns): level_vals_used = np.take(level_vals_nan, level_codes) levsize = len(level_codes) drop_cols = [] - for key in unique_groups: + for key in new_columns: try: loc = this.columns.get_loc(key) except KeyError: diff --git a/pandas/tests/frame/test_stack_unstack.py b/pandas/tests/frame/test_stack_unstack.py index fd23ea3a7621c..4082f21254e52 100644 --- a/pandas/tests/frame/test_stack_unstack.py +++ b/pandas/tests/frame/test_stack_unstack.py @@ -1065,6 +1065,27 @@ def test_stack_preserve_categorical_dtype(self, ordered, labels): tm.assert_series_equal(result, expected) + @pytest.mark.parametrize("ordered", [False, True]) + @pytest.mark.parametrize( + "labels,data", + [ + (list("xyz"), [10, 11, 12, 13, 14, 15]), + (list("zyx"), [14, 15, 12, 13, 10, 11]), + ], + ) + def 
@pytest.mark.parametrize("ordered", [False, True])
@pytest.mark.parametrize(
    "labels,data",
    [
        (list("xyz"), [10, 11, 12, 13, 14, 15]),
        (list("zyx"), [14, 15, 12, 13, 10, 11]),
    ],
)
def test_stack_multi_preserve_categorical_dtype(self, ordered, labels, data):
    # GH-36991: stacking both column levels at once must keep the
    # categorical dtype of each level in the resulting row MultiIndex.
    sorted_labels = sorted(labels)

    # Two-level categorical columns; the outer level may be given unsorted.
    outer = pd.CategoricalIndex(labels, categories=sorted_labels, ordered=ordered)
    inner = pd.CategoricalIndex(["u", "v"], ordered=ordered)
    columns = MultiIndex.from_product([outer, inner])
    frame = DataFrame([sorted(data)], columns=columns)

    # Expected index: the single row label followed by both categorical
    # levels, with the outer level in sorted order.
    sorted_outer = pd.CategoricalIndex(sorted_labels, ordered=ordered)
    expected_index = MultiIndex.from_product([[0], sorted_outer, inner])
    expected = Series(data, index=expected_index)

    result = frame.stack([0, 1])

    tm.assert_series_equal(result, expected)