Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bindcurve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
plot_curves,
plot_fits,
plot_observations,
plot_residuals,
)
from bindcurve.results import FitMetrics, FitResult, FitResults, ParameterEstimate

Expand Down Expand Up @@ -78,4 +79,5 @@
"plot_curves",
"plot_fits",
"plot_observations",
"plot_residuals",
]
2 changes: 2 additions & 0 deletions bindcurve/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
plot_curves,
plot_fits,
plot_observations,
plot_residuals,
)

__all__ = [
Expand All @@ -18,4 +19,5 @@
"plot_curves",
"plot_fits",
"plot_observations",
"plot_residuals",
]
92 changes: 92 additions & 0 deletions bindcurve/plotting/curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,29 @@ def _fit_confidence_band(
return y, y - band, y + band


def _residual_table_for_fit(
table: pd.DataFrame,
fit: FitResult,
*,
aggregate: bool,
) -> pd.DataFrame:
fit_table = table
if fit.experiment_id is not None:
fit_table = table[table["experiment_id"].astype(str) == str(fit.experiment_id)]
Comment on lines +321 to +322

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Handle compound-summary fits when filtering residual rows

plot_residuals() drops all points for fits produced with FitSettings(strategy="per_compound_summary") because those fits have experiment_id="compound_summary", but raw observation tables only contain real experiment IDs (e.g., exp1, exp2). This filter makes fit_table empty, so the loop silently skips plotting residuals and users get an empty diagnostic despite having a successful fit. Please special-case summary fits (or any fit ID not present in the table) so residuals are computed from the aggregated compound-level observations instead of filtering everything out.

Useful? React with 👍 / 👎.


if fit_table.empty:
return pd.DataFrame(columns=["concentration", "response", "predicted", "residual"])

if aggregate:
fit_table = _aggregate_observations(fit_table, by_experiment=False)

plotted = fit_table.copy()
predicted = np.asarray(_evaluate_fit(fit, plotted["concentration"].to_numpy()), dtype=float)
plotted["predicted"] = predicted
plotted["residual"] = plotted["response"].to_numpy(dtype=float) - predicted
return plotted


def plot_fits(
data: DoseResponseData,
results: FitResults,
Expand Down Expand Up @@ -415,6 +438,75 @@ def plot_confidence_bands(
return ax


def plot_residuals(
data: DoseResponseData,
results: FitResults,
*,
compound_id: str | None = None,
ax: Axes | None = None,
experiments: Iterable[str] | None = None,
aggregate: bool = True,
xscale: XScale = "log",
zero_line: bool = True,
label: str | None = None,
zero_line_kwargs: dict | None = None,
**scatter_kwargs,
) -> Axes:
"""Plot fit residuals against concentration on an existing axes.

Residuals are computed as ``observed - predicted``. By default, technical
replicates are aggregated in the same way as fitted observations.
"""
ax = _get_axes(ax)
resolved_compound_id = _resolve_compound_id(data, compound_id)
compound = data.select_compound(resolved_compound_id)
table = _filter_experiments(compound.table, experiments)
fits = _matching_fits(
results,
compound_id=resolved_compound_id,
experiments=experiments,
)

if table.empty:
raise ValueError("No observations remain after filtering.")

default_scatter_kwargs = {"marker": "o"}
default_scatter_kwargs.update(scatter_kwargs)

for fit in fits:
residuals = _residual_table_for_fit(table, fit, aggregate=aggregate)
if residuals.empty:
continue

residual_label = label
if residual_label is None:
residual_label = str(fit.experiment_id or fit.model_name)

ax.scatter(
residuals["concentration"],
residuals["residual"],
label=residual_label,
**default_scatter_kwargs,
)

if zero_line:
default_zero_line_kwargs = {"linestyle": "--", "linewidth": 1.0, "alpha": 0.7}
default_zero_line_kwargs.update(zero_line_kwargs or {})
ax.axhline(0.0, **default_zero_line_kwargs)

if xscale is not None:
ax.set_xscale(xscale)
concentration_label = "concentration"
if data.concentration_unit is not None:
concentration_label = f"concentration ({data.concentration_unit})"
response_label = "residual"
if data.response_unit is not None:
response_label = f"residual ({data.response_unit})"
ax.set_xlabel(concentration_label)
ax.set_ylabel(response_label)
return ax


def plot_asymptotes(
data: DoseResponseData,
results: FitResults,
Expand Down
58 changes: 58 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,61 @@ def test_plot_confidence_bands_requires_covariance_matrix():
experiments=["exp1"],
)
plt.close(fig)


def test_plot_residuals_draws_aggregated_residuals_and_zero_line():
data = make_data()
results = make_results(data)
fig, ax = plt.subplots()

returned_ax = bc.plot_residuals(data, results, ax=ax, experiments=["exp1"])

assert returned_ax is ax
assert len(ax.collections) == 1
assert len(ax.lines) == 1
assert ax.lines[0].get_ydata()[0] == 0.0
assert ax.get_xlabel() == "concentration (uM)"
assert ax.get_ylabel() == "residual (percent)"
assert ax.get_xscale() == "log"
plt.close(fig)


def test_plot_residuals_can_plot_raw_replicate_residuals():
data = make_data()
results = make_results(data)
fig, ax = plt.subplots()

bc.plot_residuals(data, results, ax=ax, experiments=["exp1"], aggregate=False)

offsets = ax.collections[0].get_offsets()
assert len(offsets) == 24
plt.close(fig)


def test_plot_residuals_can_disable_zero_line_and_use_linear_xscale():
data = make_data()
results = make_results(data)
fig, ax = plt.subplots()

bc.plot_residuals(
data,
results,
ax=ax,
experiments=["exp1"],
zero_line=False,
xscale="linear",
)

assert len(ax.lines) == 0
assert ax.get_xscale() == "linear"
plt.close(fig)


def test_plot_residuals_requires_matching_successful_fits():
data = make_data()
results = make_results(data)
fig, ax = plt.subplots()

with pytest.raises(ValueError, match="No successful fits"):
bc.plot_residuals(data, results, ax=ax, experiments=["missing"])
plt.close(fig)
1 change: 1 addition & 0 deletions tests/test_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_public_api_exports_new_objects():
"plot_curves",
"plot_fits",
"plot_observations",
"plot_residuals",
}

assert expected <= set(bc.__all__)
Expand Down
Loading