# ======================================================================
# src/spac/visualization.py
# ======================================================================
# Addition to the module's import section (placed after
# `from collections import OrderedDict`).
from typing import List, Optional


# Appended at the end of the module, after `present_summary_as_figure`.
# `sc` (scanpy) and `check_annotation` are names already available in
# spac/visualization.py.

def _plot_qc_violin(adata, stat_columns_list, size, log, plot_options):
    """
    Draw one violin figure for `stat_columns_list` and package the result.

    Parameters
    ----------
    adata : AnnData object
        Data (or a subset of it) whose `.obs` columns are plotted.
    stat_columns_list : list of str
        Names of the `.obs` columns to plot.
    size : float
        Size of the jitter points.
    log : bool
        Whether to plot on a logarithmic axis.
    plot_options : dict
        Extra keyword arguments forwarded to `sc.pl.violin`.

    Returns
    -------
    dict
        {'figure': ..., 'axes': ...} taken from the object returned by
        `sc.pl.violin`.
    """
    # `show=False` makes Scanpy return the plot object instead of
    # displaying it; `groupby=None` plots all cells of `adata` together.
    violin_plot = sc.pl.violin(
        adata,
        stat_columns_list,
        size=size,
        groupby=None,
        log=log,
        show=False,
        **plot_options
    )
    return {
        "figure": violin_plot.figure,
        "axes": violin_plot.axes
    }


# single cell quality control metrics violin plot
def plot_qc_metrics(
    adata,
    stat_columns_list: Optional[List[str]] = None,
    annotation: Optional[str] = None,
    log: bool = False,
    size: float = 1,
    table: Optional[str] = None,
    **kwargs
) -> dict:
    """
    Generate violin plots for quality control metrics from an AnnData object.

    Parameters
    ----------
    adata : AnnData object
        Data whose `.obs` holds the quality control columns.
    stat_columns_list : list of str, optional
        Names of the `adata.obs` columns to plot.
        If None, defaults to ['nFeature', 'nCount', 'percent.mt'].
    annotation : str or None, optional
        Column name in adata.obs to group the data by (default: None).
        One separate figure is produced per group; rows whose annotation
        value is missing (NaN) are skipped.
    log : bool, optional
        Whether to plot on a logarithmic axis (default: False).
    size : float, optional
        Size of the points in the violin plot (default: 1).
    table : str or None, optional
        Accepted for interface compatibility; currently not used by this
        function (default: None).
    **kwargs : dict
        Additional keyword arguments passed to `sc.pl.violin`. The
        defaults `jitter=0.4`, `multi_panel=True` and `use_raw=False`
        can be overridden here. The returned object must still expose
        `.figure` and `.axes`.

    Returns
    -------
    dict
        If annotation is None, returns a dictionary with keys
        'figure' and 'axes' for the whole dataset.
        If annotation is provided, returns a dictionary mapping each group
        to its own {'figure', 'axes'} dict for the subsetted AnnData.

    Raises
    ------
    ValueError
        If `stat_columns_list` is provided but empty.

    Notes
    -----
    The figures are returned open; the caller is responsible for closing
    them (e.g. `matplotlib.pyplot.close`) when many groups are plotted.
    """

    # if not provided select default stat columns
    if stat_columns_list is None:
        stat_columns_list = ['nFeature', 'nCount', 'percent.mt']
    elif len(stat_columns_list) == 0:
        raise ValueError(
            "stat_columns_list must contain at least one column name."
        )

    # Check that required columns exist in adata.obs
    check_annotation(
        adata,
        annotations=stat_columns_list,
        should_exist=True)

    # Defaults first, caller-supplied options last, so that e.g.
    # `jitter=False` overrides the default instead of raising
    # "got multiple values for keyword argument".
    plot_options = {"jitter": 0.4, "multi_panel": True, "use_raw": False}
    plot_options.update(kwargs)

    if annotation is None:
        return _plot_qc_violin(
            adata, stat_columns_list, size, log, plot_options
        )

    check_annotation(adata, annotations=annotation)
    labels = adata.obs[annotation]
    results = {}
    # NaN never compares equal to itself, so a NaN group would select
    # zero cells; drop missing labels before iterating.
    for group in labels.dropna().unique():
        adata_subset = adata[labels == group]
        results[group] = _plot_qc_violin(
            adata_subset, stat_columns_list, size, log, plot_options
        )
    return results


# ======================================================================
# tests/test_visualization/test_plot_qc_metrics.py  (new file)
# ======================================================================
import unittest

import matplotlib
matplotlib.use("Agg")  # headless backend: never open a plot window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from spac.visualization import plot_qc_metrics


class TestPlotQCMetrics(unittest.TestCase):
    """Tests for spac.visualization.plot_qc_metrics."""

    def setUp(self):
        # Seed per test so the fixture does not depend on test order.
        np.random.seed(42)

    def tearDown(self):
        # Release the figures created by the test.
        plt.close("all")

    def create_test_adata(self):
        """Build a 10-cell AnnData with QC columns and a two-level group."""
        n_cells = 10
        X = np.random.rand(n_cells, 3)
        obs = pd.DataFrame(
            {
                "nCount": np.random.randint(100, 1000, n_cells),
                "nFeature": np.random.randint(10, 100, n_cells),
                "percent.mt": np.random.rand(n_cells) * 10,
                "group": ["A", "B"] * (n_cells // 2),
            },
            # Explicit string names avoid AnnData's index-conversion warning.
            index=[f"cell{i}" for i in range(n_cells)],
        )
        var = pd.DataFrame(index=["gene1", "gene2", "gene3"])
        adata = AnnData(X=X, obs=obs, var=var)
        return adata

    def assert_figure_and_axes(self, result):
        """Check that `result` is a {'figure', 'axes'} dictionary."""
        self.assertIsInstance(result, dict)
        self.assertIn("figure", result)
        self.assertIn("axes", result)
        self.assertIsInstance(result["figure"], Figure)
        # axes can be a numpy array of Axes or a single Axes
        axes = result["axes"]
        if isinstance(axes, Axes):
            return
        self.assertIsInstance(axes, np.ndarray)
        for ax in axes.flat:
            self.assertIsInstance(ax, Axes)

    def test_plot_qc_metrics_returns_figure_and_axes(self):
        adata = self.create_test_adata()
        result = plot_qc_metrics(adata)
        self.assert_figure_and_axes(result)

    def test_plot_qc_metrics_with_annotation_column(self):
        adata = self.create_test_adata()
        result = plot_qc_metrics(adata, annotation="group")
        self.assertEqual(set(result), {"A", "B"})
        for group_result in result.values():
            self.assert_figure_and_axes(group_result)

    def test_plot_qc_metrics_with_log(self):
        adata = self.create_test_adata()
        result = plot_qc_metrics(adata, log=True)
        self.assert_figure_and_axes(result)

    def test_plot_qc_metrics_with_custom_columns(self):
        adata = self.create_test_adata()
        result = plot_qc_metrics(adata, stat_columns_list=["nCount"])
        self.assert_figure_and_axes(result)

    def test_plot_qc_metrics_kwargs_override_defaults(self):
        # Overriding a default such as `jitter` must not raise TypeError.
        adata = self.create_test_adata()
        result = plot_qc_metrics(adata, jitter=False)
        self.assert_figure_and_axes(result)

    def test_plot_qc_metrics_empty_columns_raises(self):
        adata = self.create_test_adata()
        with self.assertRaises(ValueError):
            plot_qc_metrics(adata, stat_columns_list=[])


if __name__ == "__main__":
    unittest.main()