Source code for energytrackr.plot.builtin_plots.boxplot_comparison

"""This module implements the BoxplotComparison plot for energytrackr.

BoxplotComparison using BasePlot and composable mixins for cleaner architecture,
with notches, median-connecting line, and full/raw data sources restored.
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
from bokeh.models import ColumnDataSource, FactorRange, Range1d, Whisker
from bokeh.palettes import Category10
from bokeh.plotting import figure
from bokeh.transform import jitter

from energytrackr.plot.builtin_plots.mixins import (
    ComparisonBase,
    get_labels_and_dists,
    initial_commits,
)
from energytrackr.plot.builtin_plots.registry import register_plot
from energytrackr.plot.core.context import Context


@dataclass
class SingleStats:
    """Box-plot summary statistics for the measurements of one commit.

    Attributes:
        commit: Label identifying the commit the statistics belong to.
        quartiles: ``(Q1, median, Q3)`` of the distribution.
        whiskers: ``(lower, upper)`` whisker positions.
        notch: ``(lower, upper)`` bounds of the notch drawn around the median.
    """

    # Commit label.
    commit: str
    # Q1, median, Q3.
    quartiles: tuple[float, float, float]
    # Lower and upper whisker.
    whiskers: tuple[float, float]
    # Lower and upper notch bound around the median.
    notch: tuple[float, float]
@register_plot
class BoxplotComparison(ComparisonBase):
    """Interactive boxplot comparison for two selected commits.

    - quartiles & whiskers
    - notches around medians
    - dashed line connecting medians
    - full & raw data sources for dynamic JS callbacks
    - tilted x-axis labels for readability
    """

    # Whiskers reach at most this many IQRs beyond the box edges.
    _WHISKER_IQR_FACTOR = 1.5
    # Half-width of the median notch is ``_NOTCH_FACTOR * IQR / sqrt(n)``.
    _NOTCH_FACTOR = 1.57
    # Fraction of the whisker span added as padding on each side of the y-range.
    _Y_MARGIN_FRACTION = 0.05

    def __init__(self) -> None:
        """Initialize the BoxplotComparison plot with no cached stats or sources."""
        self._stats: list[SingleStats] | None = None
        self._sources: dict[str, ColumnDataSource] | None = None

    def _make_figure(self, ctx: Context) -> figure:
        """Create the figure with a two-factor x-range for the initial commit pair.

        Args:
            ctx: Plot context; its first energy field names the title and y-axis.

        Returns:
            The configured Bokeh figure (no glyphs yet).
        """
        labels, _ = get_labels_and_dists(ctx)
        init1, init2 = initial_commits(labels)
        fig = figure(
            x_range=FactorRange(init1, init2),
            sizing_mode="stretch_width",
            tools="pan,box_zoom,reset,save,wheel_zoom,hover",
            toolbar_location="above",
            # Reuse _title so the figure title and the reported title cannot drift apart.
            title=self._title(ctx),
            x_axis_label="Commit (short hash)",
            y_axis_label=f"{ctx.energy_fields[0]} (J)",
        )
        # Tilt x-axis labels for readability.
        for axis in fig.xaxis:
            axis.major_label_orientation = 0.8
        return fig

    def _compute_stats(self, labels: list[str], dists: list[np.ndarray]) -> list[SingleStats]:
        """Compute quartiles, whiskers and notch bounds for every commit.

        Args:
            labels: Commit labels, paired positionally with ``dists``.
            dists: One array of measurements per commit.

        Returns:
            One ``SingleStats`` per ``(label, distribution)`` pair, in input order.

        Raises:
            ValueError: If a distribution contains no samples.
        """
        stats: list[SingleStats] = []
        for lbl, arr in zip(labels, dists, strict=False):
            samples = np.asarray(arr, float)
            if samples.size == 0:
                # Fail with the offending commit instead of NumPy's generic
                # "zero-size array" reduction error.
                msg = f"Cannot compute boxplot statistics for commit {lbl!r}: no samples"
                raise ValueError(msg)
            q1, med, q3 = np.percentile(samples, [25, 50, 75])
            iqr = q3 - q1
            reach = self._WHISKER_IQR_FACTOR * iqr
            # Whiskers are clamped so they never extend past the observed extremes.
            lower = max(samples.min(), q1 - reach)
            upper = min(samples.max(), q3 + reach)
            half_notch = self._NOTCH_FACTOR * iqr / np.sqrt(len(samples))
            stats.append(SingleStats(lbl, (q1, med, q3), (lower, upper), (med - half_notch, med + half_notch)))
        return stats

    @staticmethod
    def _split_by_whiskers(values: np.ndarray, whiskers: tuple[float, float]) -> tuple[list[float], list[float]]:
        """Partition ``values`` into (inliers, outliers) relative to the whisker bounds."""
        lower, upper = whiskers
        inliers: list[float] = []
        outliers: list[float] = []
        for v in values:
            if lower <= v <= upper:
                inliers.append(v)
            elif v < lower or v > upper:
                outliers.append(v)
            # Values that compare false on every branch (NaN) land in neither list.
        return inliers, outliers

    @staticmethod
    def _stats_columns(stats: list[SingleStats]) -> dict[str, list[Any]]:
        """Flatten a list of ``SingleStats`` into the column layout shared by the box sources."""
        return {
            "commit": [s.commit for s in stats],
            "q1": [s.quartiles[0] for s in stats],
            "median": [s.quartiles[1] for s in stats],
            "q3": [s.quartiles[2] for s in stats],
            "lower": [s.whiskers[0] for s in stats],
            "upper": [s.whiskers[1] for s in stats],
            "n_low": [s.notch[0] for s in stats],
            "n_high": [s.notch[1] for s in stats],
        }

    def _make_sources(self, ctx: Context) -> dict[str, Any]:  # noqa: PLR0914  # pylint: disable=too-many-locals
        """Build every data source used by the glyphs and the JS callback.

        Caches the per-commit statistics on ``self._stats`` and the sources on
        ``self._sources``.

        Args:
            ctx: Plot context from which labels and distributions are read.

        Returns:
            Mapping with keys ``full``, ``raw``, ``box``, ``scatter1``,
            ``scatter2``, ``line``, ``inlier_scatter`` and ``outlier_scatter``.
        """
        labels, dists = get_labels_and_dists(ctx)
        stats = self._compute_stats(labels, dists)

        init1, init2 = initial_commits(labels)
        i1, i2 = labels.index(init1), labels.index(init2)
        vals1, vals2 = dists[i1], dists[i2]
        pair = [stats[i1], stats[i2]]

        # Split the raw points of the initial pair into inliers and outliers.
        norm1, out1 = self._split_by_whiskers(vals1, stats[i1].whiskers)
        norm2, out2 = self._split_by_whiskers(vals2, stats[i2].whiskers)
        inlier_scatter = ColumnDataSource({
            "commit": [init1] * len(norm1) + [init2] * len(norm2),
            "value": norm1 + norm2,
        })
        outlier_scatter = ColumnDataSource({
            "commit": [init1] * len(out1) + [init2] * len(out2),
            "value": out1 + out2,
        })

        # Statistics for all commits.
        full_src = ColumnDataSource(self._stats_columns(stats))

        # Raw distributions for all commits.
        raw_src = ColumnDataSource({
            "commit": labels,
            "values": dists,
        })

        # Box + notch + whisker source for the initial pair; the two box colours
        # are the first two entries of the 3-entry Category10 palette.
        box_columns = self._stats_columns(pair)
        box_columns["commit"] = [init1, init2]
        box_columns["color"] = Category10[3][:2]
        box_src = ColumnDataSource(box_columns)

        # Per-commit scatter sources holding all raw points of the initial pair.
        scat1 = ColumnDataSource({"commit": [init1] * len(vals1), "value": vals1})
        scat2 = ColumnDataSource({"commit": [init2] * len(vals2), "value": vals2})

        # Line source connecting the two medians.
        line_src = ColumnDataSource({
            "x": [init1, init2],
            "y": [s.quartiles[1] for s in pair],
        })

        self._stats = stats
        self._sources = {
            "full": full_src,
            "raw": raw_src,
            "box": box_src,
            "scatter1": scat1,
            "scatter2": scat2,
            "line": line_src,
            "inlier_scatter": inlier_scatter,
            "outlier_scatter": outlier_scatter,
        }
        return self._sources

    def _draw_glyphs(self, fig: figure, sources: dict[str, ColumnDataSource], ctx: Context) -> None:  # noqa: ARG002
        """Draw notches, boxes, whiskers, medians, the median line and raw points.

        Also sets the figure's y-range from the whiskers of all commits.

        Args:
            fig: Figure to draw on.
            sources: Mapping produced by ``_make_sources``.
            ctx: Plot context (unused).

        Raises:
            RuntimeError: If statistics have not been computed yet.
        """
        if self._stats is None:
            msg = "_make_sources must be called before _draw_glyphs"
            raise RuntimeError(msg)

        box_src = sources["box"]

        # Notch (thin bar around median).
        fig.vbar(
            x="commit",
            bottom="n_low",
            top="n_high",
            width=0.3,
            source=box_src,
            fill_color="color",
            fill_alpha=0.5,
            line_color="red",
            legend_label="Notch",
        )

        # Box between Q1 and Q3.
        fig.vbar(
            x="commit",
            bottom="q1",
            top="q3",
            width=0.6,
            source=box_src,
            fill_color="color",
            fill_alpha=0.3,
            line_color="black",
            legend_label="IQR box",
        )

        # Whiskers.
        whisk = Whisker(source=box_src, base="commit", lower="lower", upper="upper")
        for head in (whisk.upper_head, whisk.lower_head):
            if head is not None:
                head.size = 8
        fig.add_layout(whisk)

        # Median ticks.
        fig.rect(
            x="commit",
            y="median",
            width=0.6,
            height=1,
            height_units="screen",
            source=box_src,
            color="red",
            legend_label="Median",
        )

        # Dashed line connecting the two medians.
        fig.line(
            x="x",
            y="y",
            source=sources["line"],
            line_dash="dashed",
            line_color="firebrick",
            line_width=2,
            legend_label="Median trend",
        )

        # Draw inliers lightly.
        fig.circle(
            x=jitter("commit", width=0.3, range=fig.x_range),
            y="value",
            source=sources["inlier_scatter"],
            radius=0.01,  # small, subtle
            alpha=0.3,
            color="grey",
            legend_label="Inliers",
        )

        # Draw outliers in bold.
        fig.circle(
            x=jitter("commit", width=0.3, range=fig.x_range),
            y="value",
            source=sources["outlier_scatter"],
            radius=0.02,  # larger marker
            alpha=0.8,  # more opaque
            fill_color="firebrick",
            line_color="black",
            line_width=1,
            legend_label="Outliers",
        )

        # Fit the y-range to the whiskers of all commits, with a margin.
        y_min = min(s.whiskers[0] for s in self._stats)
        y_max = max(s.whiskers[1] for s in self._stats)
        span = y_max - y_min
        if span > 0:
            margin = span * self._Y_MARGIN_FRACTION
        else:
            # All whiskers coincide (constant measurements): a proportional margin
            # would be zero and collapse the axis to a zero-height range.
            margin = abs(y_max) * self._Y_MARGIN_FRACTION or 1.0
        fig.y_range = Range1d(start=y_min - margin, end=y_max + margin)

    def _title(self, ctx: Context) -> str:  # noqa: PLR6301
        """Return the plot title, built from the first energy field of ``ctx``."""
        return f"Distribution Boxplot: {ctx.energy_fields[0]}"

    def _key(self, ctx: Context) -> str:  # noqa: ARG002, PLR6301
        """Return the fixed key ``"Boxplot"`` for this plot."""
        return "Boxplot"

    def _callback_js_path(self) -> Path:  # noqa: PLR6301
        """Return the path to ``static/boxplot_comparison.js`` next to this module."""
        return Path(__file__).parent / "static" / "boxplot_comparison.js"

    def _callback_args(self, fig: figure, ctx: Context) -> dict[str, Any]:  # noqa: ARG002
        """Return the JS callback arguments: every cached source plus the figure as ``plot``.

        Raises:
            RuntimeError: If the data sources have not been built yet.
        """
        if self._sources is None:
            msg = "_make_sources must be called before _callback_args"
            raise RuntimeError(msg)
        return {**self._sources, "plot": fig}

    def _hover_tooltips(self, ctx: Context) -> list[tuple[str, str]]:  # noqa: ARG002, PLR6301
        """Return the hover tooltip rows (commit, quartiles and whisker bounds)."""
        return [
            ("Commit", "@commit"),
            ("Q1", "@q1{0.00} J"),
            ("Median", "@median{0.00} J"),
            ("Q3", "@q3{0.00} J"),
            ("Lower", "@lower{0.00} J"),
            ("Upper", "@upper{0.00} J"),
        ]