Skip to content

Commit 65bf9cd

Browse files
authored
BUG: Fixed issue with bar plots not stacking correctly when 'stacked' and 'subplots' are used together (#61340)
* test case for subplot stacking * Removed overlooked print statement * Updated test to check other subplot in figure * Updated test cases to include more subplot stacking possibilities * removed savefig() left in test cases * Updated test cases to test more arrangements * Completed function fix (order of subplot input does not matter, need clarification if it matters) * appeasing the great pre-commit formatter * Updated whatsnew * Docstring adjustment * Moved self.subplot check to a seperate bool * Added ignore where mypy thinks self.subplots is a bool * Actually addressed mypy typing * Incorperated initial PR comments * Updated missing () after .all * Addressed more comments on PR * Updated '&' to 'and'
1 parent d79f7b0 commit 65bf9cd

File tree

3 files changed

+172
-1
lines changed

3 files changed

+172
-1
lines changed

‎doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,7 @@ Period
795795
Plotting
796796
^^^^^^^^
797797
- Bug in :meth:`.DataFrameGroupBy.boxplot` failed when there were multiple groupings (:issue:`14701`)
798+
- Bug in :meth:`DataFrame.plot.bar` when ``subplots`` and ``stacked=True`` are used in conjunction which causes incorrect stacking. (:issue:`61018`)
798799
- Bug in :meth:`DataFrame.plot.bar` with ``stacked=True`` where labels on stacked bars with zero-height segments were incorrectly positioned at the base instead of the label position of the previous segment (:issue:`59429`)
799800
- Bug in :meth:`DataFrame.plot.line` raising ``ValueError`` when set both color and a ``dict`` style (:issue:`59461`)
800801
- Bug in :meth:`DataFrame.plot` that causes a shift to the right when the frequency multiplier is greater than one. (:issue:`57587`)

‎pandas/plotting/_matplotlib/core.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -1928,6 +1928,20 @@ def _make_plot(self, fig: Figure) -> None:
19281928
K = self.nseries
19291929

19301930
data = self.data.fillna(0)
1931+
1932+
_stacked_subplots_ind: dict[int, int] = {}
1933+
_stacked_subplots_offsets = []
1934+
1935+
self.subplots: list[Any]
1936+
1937+
if bool(self.subplots) and self.stacked:
1938+
for i, sub_plot in enumerate(self.subplots):
1939+
if len(sub_plot) <= 1:
1940+
continue
1941+
for plot in sub_plot:
1942+
_stacked_subplots_ind[int(plot)] = i
1943+
_stacked_subplots_offsets.append([0, 0])
1944+
19311945
for i, (label, y) in enumerate(self._iter_data(data=data)):
19321946
ax = self._get_ax(i)
19331947
kwds = self.kwds.copy()
@@ -1953,7 +1967,28 @@ def _make_plot(self, fig: Figure) -> None:
19531967
start = start + self._start_base
19541968

19551969
kwds["align"] = self._align
1956-
if self.subplots:
1970+
1971+
if i in _stacked_subplots_ind:
1972+
offset_index = _stacked_subplots_ind[i]
1973+
pos_prior, neg_prior = _stacked_subplots_offsets[offset_index] # type:ignore[assignment]
1974+
mask = y >= 0
1975+
start = np.where(mask, pos_prior, neg_prior) + self._start_base
1976+
w = self.bar_width / 2
1977+
rect = self._plot(
1978+
ax,
1979+
self.ax_pos + w,
1980+
y,
1981+
self.bar_width,
1982+
start=start,
1983+
label=label,
1984+
log=self.log,
1985+
**kwds,
1986+
)
1987+
pos_new = pos_prior + np.where(mask, y, 0)
1988+
neg_new = neg_prior + np.where(mask, 0, y)
1989+
_stacked_subplots_offsets[offset_index] = [pos_new, neg_new]
1990+
1991+
elif self.subplots:
19571992
w = self.bar_width / 2
19581993
rect = self._plot(
19591994
ax,

‎pandas/tests/plotting/test_misc.py

+135
Original file line numberDiff line numberDiff line change
@@ -681,3 +681,138 @@ def test_bar_plt_xaxis_intervalrange(self):
681681
(a.get_text() == b.get_text())
682682
for a, b in zip(s.plot.bar().get_xticklabels(), expected)
683683
)
684+
685+
686+
@pytest.fixture
687+
def df_bar_data():
688+
return np.random.default_rng(3).integers(0, 100, 5)
689+
690+
691+
@pytest.fixture
692+
def df_bar_df(df_bar_data) -> DataFrame:
693+
df_bar_df = DataFrame(
694+
{
695+
"A": df_bar_data,
696+
"B": df_bar_data[::-1],
697+
"C": df_bar_data[0],
698+
"D": df_bar_data[-1],
699+
}
700+
)
701+
return df_bar_df
702+
703+
704+
def _df_bar_xyheight_from_ax_helper(df_bar_data, ax, subplot_division):
705+
subplot_data_df_list = []
706+
707+
# get xy and height of squares representing data, separated by subplots
708+
for i in range(len(subplot_division)):
709+
subplot_data = np.array(
710+
[
711+
(x.get_x(), x.get_y(), x.get_height())
712+
for x in ax[i].findobj(plt.Rectangle)
713+
if x.get_height() in df_bar_data
714+
]
715+
)
716+
subplot_data_df_list.append(
717+
DataFrame(data=subplot_data, columns=["x_coord", "y_coord", "height"])
718+
)
719+
720+
return subplot_data_df_list
721+
722+
723+
def _df_bar_subplot_checker(df_bar_data, df_bar_df, subplot_data_df, subplot_columns):
724+
subplot_sliced_by_source = [
725+
subplot_data_df.iloc[
726+
len(df_bar_data) * i : len(df_bar_data) * (i + 1)
727+
].reset_index()
728+
for i in range(len(subplot_columns))
729+
]
730+
expected_total_height = df_bar_df.loc[:, subplot_columns].sum(axis=1)
731+
732+
for i in range(len(subplot_columns)):
733+
sliced_df = subplot_sliced_by_source[i]
734+
if i == 0:
735+
# Checks that the bar chart starts y=0
736+
assert (sliced_df["y_coord"] == 0).all()
737+
height_iter = sliced_df["y_coord"].add(sliced_df["height"])
738+
else:
739+
height_iter = height_iter + sliced_df["height"]
740+
741+
if i + 1 == len(subplot_columns):
742+
# Checks final height matches what is expected
743+
tm.assert_series_equal(
744+
height_iter, expected_total_height, check_names=False, check_dtype=False
745+
)
746+
747+
else:
748+
# Checks each preceding bar ends where the next one starts
749+
next_start_coord = subplot_sliced_by_source[i + 1]["y_coord"]
750+
tm.assert_series_equal(
751+
height_iter, next_start_coord, check_names=False, check_dtype=False
752+
)
753+
754+
755+
# GH Issue 61018
756+
@pytest.mark.parametrize("columns_used", [["A", "B"], ["C", "D"], ["D", "A"]])
757+
def test_bar_1_subplot_1_double_stacked(df_bar_data, df_bar_df, columns_used):
758+
df_bar_df_trimmed = df_bar_df[columns_used]
759+
subplot_division = [columns_used]
760+
ax = df_bar_df_trimmed.plot(subplots=subplot_division, kind="bar", stacked=True)
761+
subplot_data_df_list = _df_bar_xyheight_from_ax_helper(
762+
df_bar_data, ax, subplot_division
763+
)
764+
for i in range(len(subplot_data_df_list)):
765+
_df_bar_subplot_checker(
766+
df_bar_data, df_bar_df_trimmed, subplot_data_df_list[i], subplot_division[i]
767+
)
768+
769+
770+
@pytest.mark.parametrize(
771+
"columns_used", [["A", "B", "C"], ["A", "C", "B"], ["D", "A", "C"]]
772+
)
773+
def test_bar_2_subplot_1_double_stacked(df_bar_data, df_bar_df, columns_used):
774+
df_bar_df_trimmed = df_bar_df[columns_used]
775+
subplot_division = [(columns_used[0], columns_used[1]), (columns_used[2],)]
776+
ax = df_bar_df_trimmed.plot(subplots=subplot_division, kind="bar", stacked=True)
777+
subplot_data_df_list = _df_bar_xyheight_from_ax_helper(
778+
df_bar_data, ax, subplot_division
779+
)
780+
for i in range(len(subplot_data_df_list)):
781+
_df_bar_subplot_checker(
782+
df_bar_data, df_bar_df_trimmed, subplot_data_df_list[i], subplot_division[i]
783+
)
784+
785+
786+
@pytest.mark.parametrize(
787+
"subplot_division",
788+
[
789+
[("A", "B"), ("C", "D")],
790+
[("A", "D"), ("C", "B")],
791+
[("B", "C"), ("D", "A")],
792+
[("B", "D"), ("C", "A")],
793+
],
794+
)
795+
def test_bar_2_subplot_2_double_stacked(df_bar_data, df_bar_df, subplot_division):
796+
ax = df_bar_df.plot(subplots=subplot_division, kind="bar", stacked=True)
797+
subplot_data_df_list = _df_bar_xyheight_from_ax_helper(
798+
df_bar_data, ax, subplot_division
799+
)
800+
for i in range(len(subplot_data_df_list)):
801+
_df_bar_subplot_checker(
802+
df_bar_data, df_bar_df, subplot_data_df_list[i], subplot_division[i]
803+
)
804+
805+
806+
@pytest.mark.parametrize(
807+
"subplot_division",
808+
[[("A", "B", "C")], [("A", "D", "B")], [("C", "A", "D")], [("D", "C", "A")]],
809+
)
810+
def test_bar_2_subplots_1_triple_stacked(df_bar_data, df_bar_df, subplot_division):
811+
ax = df_bar_df.plot(subplots=subplot_division, kind="bar", stacked=True)
812+
subplot_data_df_list = _df_bar_xyheight_from_ax_helper(
813+
df_bar_data, ax, subplot_division
814+
)
815+
for i in range(len(subplot_data_df_list)):
816+
_df_bar_subplot_checker(
817+
df_bar_data, df_bar_df, subplot_data_df_list[i], subplot_division[i]
818+
)

0 commit comments

Comments
 (0)