Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions src/spac/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import matplotlib.patches as mpatch
from functools import partial
from collections import OrderedDict
from typing import List, Optional


# Configure logging
Expand Down Expand Up @@ -3826,3 +3827,92 @@ def present_summary_as_figure(summary_dict: dict) -> go.Figure:
title="Data Summary"
)
return fig

# single cell quality control metrics violin plot
def plot_qc_metrics(
adata,
stat_columns_list: Optional[List[str]] = None,
annotation=None,
log=False,
size=1,
table=None,
**kwargs
):
"""
Generate violin plots for quality control metrics from an AnnData object.

Parameters
----------
adata : AnnData object
stat_columns_list (list): List of column names to compute statistics for.
If None, defaults to ['nFeature', 'nCount', 'percent.mt'].
annotation : str or None, optional
Column name in adata.obs to group the data by (default: None).
log : bool, optional
Whether to log-transform the data (default: False).
size : float, optional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@abombin , would you like to check that the size is within the correct range?

Size of the points in the violin plot (default: 1).
**kwargs : dict
Additional keyword arguments are passed to the underlying matplotlib
plotting functions and Scanpy plotting utilities. This allows customization
of plot appearance, such as axis labels, colors, figure size,
and other matplotlib options.

Returns
-------
dict
If annotation is None, returns a dictionary with keys
'figure' and 'axes' for the whole dataset.
If annotation is provided, returns a dictionary mapping each group
to its own {'figure', 'axes'} dict for the subsetted AnnData.
"""

# if not provided select default stat columns
if stat_columns_list is None:
stat_columns_list = ['nFeature', 'nCount', 'percent.mt']

# Check that required columns exist in adata.obs
check_annotation(
adata,
annotations=stat_columns_list,
should_exist=True)

if annotation is not None:
check_annotation(adata, annotations=annotation)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @abombin , you can combine that check with the check on line 3875 (just create one list with all the annotations)

results = {}
for group in adata.obs[annotation].unique():
adata_subset = adata[adata.obs[annotation] == group]
violin_plot = sc.pl.violin(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@abombin , I suggest to use "partial" to define the call once, and call it in various part of the code (e.g. here and in line 3903). This way, if you ever change it, you change it in one place.

adata_subset,
stat_columns_list,
size=size,
groupby=None,
log=log,
jitter=0.4,
multi_panel=True,
show=False,
use_raw=False,
**kwargs
)
results[group] = {
"figure": violin_plot.figure,
"axes": violin_plot.axes
}
return results
else:
violin_plot = sc.pl.violin(
adata,
stat_columns_list,
size=size,
groupby=None,
log=log,
jitter=0.4,
multi_panel=True,
show=False,
use_raw=False,
**kwargs
)
return {
"figure": violin_plot.figure,
"axes": violin_plot.axes
}
58 changes: 58 additions & 0 deletions tests/test_visualization/test_plot_qc_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import unittest
from unittest import result
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from spac.visualization import plot_qc_metrics
import numpy as np

class TestPlotQCMetrics(unittest.TestCase):
@classmethod
def setUpClass(cls):
np.random.seed(42)

def create_test_adata(self):
X = np.random.rand(10, 3)
obs = pd.DataFrame({
"nCount": np.random.randint(100, 1000, 10),
"nFeature": np.random.randint(10, 100, 10),
"percent.mt": np.random.rand(10) * 10,
"group": ["A", "B"] * 5
})
var = pd.DataFrame(index=["gene1", "gene2", "gene3"])
adata = AnnData(X=X, obs=obs, var=var)
return adata

def test_plot_qc_metrics_returns_figure_and_axes(self):
adata = self.create_test_adata()
result = plot_qc_metrics(adata)
self.assertIsInstance(result, dict)
self.assertIn("figure", result)
self.assertIn("axes", result)
self.assertIsInstance(result["figure"], Figure)
# axes can be a numpy array or a single Axes
axes = result["axes"]
self.assertTrue(isinstance(axes, (np.ndarray, Axes)))

def test_plot_qc_metrics_with_annotation_column(self):
adata = self.create_test_adata()
result = plot_qc_metrics(adata, annotation="group")
self.assertIsInstance(result["A"]["figure"], Figure)
print(result["A"]["axes"])
print(type(result["A"]["axes"]))
self.assertTrue(
isinstance(result["A"]["axes"], Axes) or
all(isinstance(ax, Axes) for ax in result["A"]["axes"].flat)
)

def test_plot_qc_metrics_with_log(self):
adata = self.create_test_adata()
result = plot_qc_metrics(adata, log=True)
self.assertIsInstance(result["figure"], Figure)
self.assertTrue(isinstance(result["axes"], (np.ndarray, Axes)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@abombin , would you like to check other aspects of the figure other than type of what is being returned? Something to verify that the figure actually has a valid plot in it?


if __name__ == "__main__":
unittest.main()