Source code for energytrackr.plot.builtin_plots.ecdf_comparison

"""ECDFComparison using BasePlot and mixins for cleaner composition."""

from __future__ import annotations

from collections.abc import Sequence
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
from bokeh.models import ColumnDataSource, Range1d
from bokeh.palettes import Category10
from bokeh.plotting import figure

from energytrackr.plot.builtin_plots.mixins import (
    ComparisonBase,
    get_labels_and_dists,
    initial_commits,
)
from energytrackr.plot.builtin_plots.registry import register_plot
from energytrackr.plot.core.context import Context


def _value_range(dists: Sequence[Any]) -> tuple[float, float]:
    """Return ``(min, max)`` over every value of every distribution in *dists*.

    Args:
        dists: Per-commit measurement arrays (anything ``np.asarray`` accepts).

    Returns:
        The smallest and largest value across all distributions, as floats.

    Raises:
        ValueError: If *dists* is empty or every distribution is empty.
    """
    arrays = [np.asarray(arr, float) for arr in dists]
    # Fail with a readable message instead of numpy's concatenate/reduction errors.
    if not arrays or all(a.size == 0 for a in arrays):
        msg = "no energy measurements available to compute an ECDF range"
        raise ValueError(msg)
    all_vals = np.concatenate(arrays)
    return float(all_vals.min()), float(all_vals.max())


@register_plot
class ECDFComparison(ComparisonBase):
    """Interactive ECDF comparison between two selected commits.

    Two step lines ("Commit A" / "Commit B") are drawn from per-commit ECDF data.
    A JavaScript callback (``static/ecdf_comparison.js``) receives the full
    per-commit ECDF table, both step sources, the figure and the commit labels.
    It presumably re-points the two lines at other commits; see the JS file.
    """

    def __init__(self) -> None:
        """Initialize the ECDFComparison plot with empty per-render state."""
        # Filled in by _make_sources / _draw_glyphs and read by the later hooks.
        self._full_ecdf: ColumnDataSource | None = None
        self._labels: Sequence[str] | None = None
        self._init_indices: tuple[int, int] | None = None
        self._ds1: ColumnDataSource | None = None
        self._ds2: ColumnDataSource | None = None

    def _make_sources(self, ctx: Context) -> dict[str, Any]:
        """Compute per-commit ECDF data and remember the initial commit pair.

        Args:
            ctx: Plot context; labels and distributions come from ``get_labels_and_dists``.

        Returns:
            ``{"full_ecdf": <ColumnDataSource with ecdf_x/ecdf_y>, "labels": <labels>}``.
        """
        labels, dists = get_labels_and_dists(ctx)
        src_full, _ = self._compute_ecdf(dists)
        # Initial indices (oldest vs newest), as chosen by initial_commits.
        i1, i2 = initial_commits(labels)
        # Stored on the instance because _draw_glyphs and _callback_args read them.
        self._full_ecdf = src_full
        self._labels = labels
        self._init_indices = (int(i1), int(i2))
        return {"full_ecdf": src_full, "labels": labels}

    def _make_figure(self, ctx: Context) -> figure:
        """Create the figure with an x-range spanning every commit's measurements.

        Args:
            ctx: Plot context; values are read from ``ctx.artefacts["distributions"]``.

        Returns:
            A Bokeh figure labelled with the first energy field.

        Raises:
            ValueError: If the distributions contain no values.
        """
        dists = ctx.artefacts["distributions"]
        x_min, x_max = _value_range(dists)
        if x_min == x_max:
            # A zero-width range leaves nothing to map the data onto; widen it so
            # a single repeated value still renders as a visible step.
            pad = abs(x_min) * 0.05 or 0.5
            x_min, x_max = x_min - pad, x_max + pad
        return figure(
            title=self._title(ctx),
            x_range=Range1d(start=x_min, end=x_max),
            sizing_mode="stretch_width",
            tools="pan,box_zoom,reset,save,wheel_zoom,hover",
            toolbar_location="above",
            x_axis_label=f"{ctx.energy_fields[0]} (J)",
            y_axis_label="ECDF",
        )

    def _draw_glyphs(self, fig: figure, sources: dict[str, ColumnDataSource], ctx: Context) -> None:  # noqa: ARG002
        """Draw the ECDF step lines for the initial commit pair.

        Args:
            fig: Figure to draw on.
            sources: Unused; per-commit data comes from state set by ``_make_sources``.
            ctx: Unused.

        Raises:
            RuntimeError: If ``_make_sources`` has not populated the ECDF state yet.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if self._full_ecdf is None or self._init_indices is None:
            msg = "ECDFComparison._make_sources must run before _draw_glyphs"
            raise RuntimeError(msg)
        src_full = self._full_ecdf
        i1, i2 = self._init_indices

        # Separate, mutable sources per line so the JS callback can swap their data.
        raw1 = self._make_ecdf_ds(src_full, i1)
        raw2 = self._make_ecdf_ds(src_full, i2)
        ds1 = ColumnDataSource(pd.DataFrame({"x": raw1["x"], "y": raw1["y"]}))
        ds2 = ColumnDataSource(pd.DataFrame({"x": raw2["x"], "y": raw2["y"]}))
        self._ds1 = ds1
        self._ds2 = ds2

        palette = Category10[3]
        fig.step("x", "y", source=ds1, mode="after", line_width=2, line_color=palette[0], legend_label="Commit A")
        fig.step("x", "y", source=ds2, mode="after", line_width=2, line_color=palette[1], legend_label="Commit B")

    def _configure(self, fig: figure, ctx: Context) -> None:
        """Apply the base configuration, then move any legends to the bottom right."""
        super()._configure(fig, ctx)
        if fig.legend:
            for legend in fig.legend:
                legend.location = "bottom_right"

    def _callback_js_path(self) -> Path:  # noqa: PLR6301
        """Return the path of the JavaScript callback bundled next to this module."""
        return Path(__file__).parent / "static" / "ecdf_comparison.js"

    def _callback_args(self, fig: figure, ctx: Context) -> dict[str, Any]:  # noqa: ARG002
        """Return the Bokeh objects and labels exposed to the JavaScript callback."""
        return {
            "full_ecdf": self._full_ecdf,
            "src1": self._ds1,
            "src2": self._ds2,
            "plot": fig,
            "labels": self._labels,
        }

    def _hover_tooltips(self, ctx: Context) -> list[tuple[str, str]]:  # noqa: ARG002, PLR6301
        """Return hover tooltips showing the value (J) and its ECDF height."""
        return [
            ("Value", "@x{0.00} J"),
            ("ECDF", "@y{0.00}"),
        ]

    def _title(self, ctx: Context) -> str:  # noqa: PLR6301
        """Return the figure title for the first energy field."""
        return f"Empirical CDF: {ctx.energy_fields[0]}"

    def _key(self, ctx: Context) -> str:  # noqa: ARG002, PLR6301
        """Return this plot's display key."""
        return "ECDF Comparison"

    @staticmethod
    def _compute_ecdf(dists: list[np.ndarray]) -> tuple[ColumnDataSource, tuple[float, float]]:
        """Build per-commit ECDF step data.

        Args:
            dists: One array of measurements per commit.

        Returns:
            A ``ColumnDataSource`` with list-valued ``ecdf_x`` (sorted values) and
            ``ecdf_y`` (cumulative fractions ``i / n``), one row per commit, plus
            the global ``(min, max)`` over all values.

        Raises:
            ValueError: If the distributions contain no values at all.
        """
        # Convert and sort once; the sorted arrays also feed the range computation.
        sorted_arrays = [np.sort(np.asarray(arr, float)) for arr in dists]
        xs_list: list[list[float]] = []
        ys_list: list[list[float]] = []
        for data in sorted_arrays:
            n = len(data)
            xs_list.append(data.tolist())
            # ECDF height at the i-th smallest value is i / n; an empty commit yields no steps.
            ys_list.append((np.arange(1, n + 1) / n).tolist() if n else [])
        vmin, vmax = _value_range(sorted_arrays)
        return ColumnDataSource({"ecdf_x": xs_list, "ecdf_y": ys_list}), (vmin, vmax)

    @staticmethod
    def _make_ecdf_ds(full_src: ColumnDataSource, idx: int) -> dict[str, list[float]]:
        """Return the ``x``/``y`` ECDF step data for commit row *idx* of *full_src*."""
        return {
            "x": full_src.data["ecdf_x"][idx],
            "y": full_src.data["ecdf_y"][idx],
        }