From 77be5280b80a7a139f1b814bb21da0450319d8d4 Mon Sep 17 00:00:00 2001 From: Jerome Dockes Date: Mon, 1 Jul 2024 17:10:07 +0200 Subject: [PATCH] add percentages to bar plots --- src/skrubview/_plotting.py | 38 ++++++++++++++++++++++++------------- src/skrubview/_summarize.py | 2 +- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/skrubview/_plotting.py b/src/skrubview/_plotting.py index df344ca..9baf86b 100644 --- a/src/skrubview/_plotting.py +++ b/src/skrubview/_plotting.py @@ -6,20 +6,23 @@ from . import _utils # from matplotlib import colormaps, colors -# _PASTEL = list(map(colors.rgb2hex, colormaps.get_cmap("tab10").colors)) +# _TAB10 = list(map(colors.rgb2hex, colormaps.get_cmap("tab10").colors)) + +# sns.color_palette('muted').as_hex() _SEABORN = [ - "#4c72b0", - "#dd8452", - "#55a868", - "#c44e52", - "#8172b3", - "#937860", - "#da8bc3", - "#8c8c8c", - "#ccb974", - "#64b5cd", + "#4878d0", + "#ee854a", + "#6acc64", + "#d65f5f", + "#956cb4", + "#8c613c", + "#dc7ec0", + "#797979", + "#d5bb67", + "#82c6e2", ] + COLORS = _SEABORN COLOR_0 = COLORS[0] @@ -80,7 +83,7 @@ def line(x_col, y_col): return _serialize(fig) -def value_counts(value_counts, n_unique, color=COLOR_0): +def value_counts(value_counts, n_unique, n_rows, color=COLOR_0): values = [_utils.ellide_string_short(s) for s in value_counts.keys()][::-1] counts = list(value_counts.values())[::-1] if n_unique > len(value_counts): @@ -89,7 +92,16 @@ def value_counts(value_counts, n_unique, color=COLOR_0): title = None fig, ax = plt.subplots() _despine(ax) - ax.barh(list(map(str, range(len(values)))), counts, color=color) + rects = ax.barh(list(map(str, range(len(values)))), counts, color=color) + percent = [_utils.format_percent(c / n_rows) for c in counts] + large_percent = [ + f"{p: >6}" if c > counts[-1] / 2 else "" for (p, c) in zip(percent, counts) + ] + small_percent = [ + p if c <= counts[-1] / 2 else "" for (p, c) in zip(percent, counts) + ] + ax.bar_label(rects, large_percent, padding=-30, color="black", fontsize=8) + ax.bar_label(rects, small_percent, padding=5, color="black", fontsize=8) ax.set_yticks(ax.get_yticks()) ax.set_yticklabels(list(map(str, values))) if title is not None: diff --git a/src/skrubview/_summarize.py b/src/skrubview/_summarize.py index 98fe2c9..d82edf4 100644 --- a/src/skrubview/_summarize.py +++ b/src/skrubview/_summarize.py @@ -115,7 +115,7 @@ def _add_value_counts(summary, column, *, dataframe_summary, with_plots): summary["value_is_constant"] = False if with_plots: summary["value_counts_plot"] = _plotting.value_counts( - value_counts, n_unique, color=_plotting.COLORS[1] + value_counts, n_unique, dataframe_summary["n_rows"], color=_plotting.COLORS[1] )