Skip to content

Commit eb9af30

Browse files
authored
Merge pull request #70 from FNLCR-DMAP/dev
Minor updates from NIDAP application
2 parents 35e2133 + 08c767b commit eb9af30

File tree

6 files changed

+95
-25
lines changed

6 files changed

+95
-25
lines changed

src/spac/transformations.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import scanpy.external as sce
66

77

8-
def phenograph_clustering(adata, features, layer, k=30):
8+
def phenograph_clustering(adata, features, layer=None, k=30):
99
"""
1010
Calculate automatic phenotypes using phenograph.
1111
@@ -39,8 +39,8 @@ def phenograph_clustering(adata, features, layer, k=30):
3939
not all(isinstance(feature, str) for feature in features)):
4040
raise TypeError("`features` must be a list of strings")
4141

42-
if layer not in adata.layers.keys():
43-
raise ValueError(f"`layer` not found in `adata.layers`. "
42+
if layer is not None and layer not in adata.layers.keys():
43+
raise ValueError(f"`{layer}` not found in `adata.layers`. "
4444
f"Available layers are {list(adata.layers.keys())}")
4545

4646
if not isinstance(k, int) or k <= 0:
@@ -50,7 +50,11 @@ def phenograph_clustering(adata, features, layer, k=30):
5050
raise ValueError("One or more of the `features` are not in "
5151
"`adata.var_names`")
5252

53-
phenograph_df = adata.to_df(layer=layer)[features]
53+
if layer is not None:
54+
phenograph_df = adata.to_df(layer=layer)[features]
55+
else:
56+
phenograph_df = adata.to_df()[features]
57+
5458
phenograph_out = sce.tl.phenograph(phenograph_df,
5559
clustering_algo="louvain",
5660
k=k)
@@ -245,8 +249,15 @@ def rename_observations(adata, src_observation, dest_observation, mappings):
245249
adata.obs[dest_observation] = (
246250
adata.obs[src_observation]
247251
.map(mappings)
248-
.fillna(adata.obs[src_observation])
249252
.astype("category")
250253
)
251254

255+
# Ensure that all categories are covered
256+
if adata.obs[dest_observation].isna().any():
257+
raise ValueError(
258+
"Not all unique values in the source observation are "
259+
"covered by the mappings. "
260+
"Please ensure that the mappings cover all unique values."
261+
)
262+
252263
return adata

src/spac/visualization.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from matplotlib.colors import ListedColormap, BoundaryNorm
1010

1111

12-
def tsne_plot(adata, ax=None, **kwargs):
12+
def tsne_plot(adata, color_column=None, ax=None, **kwargs):
1313
"""
1414
Visualize scatter plot in tSNE basis.
1515
@@ -18,6 +18,8 @@ def tsne_plot(adata, ax=None, **kwargs):
1818
adata : anndata.AnnData
1919
The AnnData object with t-SNE coordinates precomputed by the 'tsne'
2020
function and stored in 'adata.obsm["X_tsne"]'.
21+
color_column : str, optional
22+
The name of the column to use for coloring the scatter plot points.
2123
ax : matplotlib.axes.Axes, optional (default: None)
2224
A matplotlib axes object to plot on.
2325
If not provided, a new figure and axes will be created.
@@ -42,6 +44,17 @@ def tsne_plot(adata, ax=None, **kwargs):
4244
# Create a new figure and axes if not provided
4345
if ax is None:
4446
fig, ax = plt.subplots()
47+
else:
48+
fig = ax.get_figure()
49+
50+
if color_column and (color_column not in adata.obs.columns and
51+
color_column not in adata.var.columns):
52+
err_msg = f"'{color_column}' not found in adata.obs or adata.var."
53+
raise KeyError(err_msg)
54+
55+
# Add color column to the kwargs for the scanpy plot
56+
if color_column:
57+
kwargs['color'] = color_column
4558

4659
# Plot the t-SNE
4760
sc.pl.tsne(adata, ax=ax, **kwargs)
@@ -133,10 +146,10 @@ def histogram(adata, feature_name=None, observation_name=None, layer=None,
133146
fig, axs = plt.subplots(n_groups, 1, figsize=(5, 5*n_groups))
134147
if n_groups == 1:
135148
axs = [axs]
136-
for i, ax in enumerate(axs):
149+
for i, ax_i in enumerate(axs):
137150
sns.histplot(data=df[df[group_by] == groups[i]].dropna(),
138-
x=x, ax=ax, **kwargs)
139-
ax.set_title(groups[i])
151+
x=x, ax=ax_i, **kwargs)
152+
ax_i.set_title(groups[i])
140153
return fig, axs
141154

142155
sns.histplot(data=df, x=x, ax=ax, **kwargs)
@@ -556,7 +569,7 @@ def spatial_plot(
556569
raise ValueError(err_msg_ax)
557570

558571
if feature is not None:
559-
572+
560573
feature_index = feature_names.index(feature)
561574
feature_obs = feature + "spatial_plot"
562575
if vmin == -999:

tests/test_transformations/test_phenograph_clustering.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ def test_typical_case(self, mock_phenograph):
5252
self.assertEqual(self.adata.uns['phenograph_features'],
5353
self.features)
5454

55+
@patch('scanpy.external.tl.phenograph',
56+
return_value=(np.random.randint(0, 3, 100), {}))
57+
def test_layer_none_case(self, mock_phenograph):
58+
# This test checks if the function works correctly when layer is None.
59+
phenograph_clustering(self.adata, self.features, None)
60+
self.assertIn('phenograph', self.adata.obs)
61+
self.assertEqual(self.adata.uns['phenograph_features'],
62+
self.features)
63+
5564
def test_invalid_adata(self):
5665
# This test checks if the function raises a TypeError when the
5766
# adata argument is not an AnnData object.

tests/test_transformations/test_rename_observations.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,21 +61,6 @@ def test_invalid_mappings(self):
6161
{"5": "group_8"}
6262
)
6363

64-
def test_partial_mappings(self):
65-
"""Test rename_observations with partial mappings."""
66-
mappings = {"0": "group_8", "1": "group_2"}
67-
dest_observation = "renamed_observations"
68-
result = rename_observations(
69-
self.adata, "phenograph", dest_observation, mappings
70-
)
71-
expected = pd.Series(
72-
["group_8", "group_2", "group_8", "2", "group_2", "2"],
73-
index=self.adata.obs.index,
74-
name=dest_observation,
75-
dtype="category"
76-
)
77-
pd.testing.assert_series_equal(result.obs[dest_observation], expected)
78-
7964
def test_rename_observations_basic(self):
8065
"""Test basic functionality of rename_observations."""
8166
data_matrix = np.random.rand(3, 4)
@@ -157,6 +142,26 @@ def test_multiple_observations_to_one_group(self):
157142
all(renamed_clusters == ["group_0", "group_0", "group_0"])
158143
)
159144

145+
def test_not_all_categories_covered(self):
146+
"""
147+
Test rename_observations with mappings that do not cover
148+
all unique values in the source observation.
149+
"""
150+
mappings = {"0": "group_8", "1": "group_2"}
151+
with self.assertRaises(ValueError) as cm:
152+
rename_observations(
153+
self.adata,
154+
"phenograph",
155+
"incomplete_dest",
156+
mappings
157+
)
158+
self.assertEqual(
159+
str(cm.exception),
160+
"Not all unique values in the source observation are "
161+
"covered by the mappings. "
162+
"Please ensure that the mappings cover all unique values."
163+
)
164+
160165

161166
if __name__ == "__main__":
162167
unittest.main()

tests/test_visualization/test_histogram.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,24 @@ def test_histogram_observation_name(self):
3232
self.assertEqual(sum(p.get_height() for p in ax.patches), total_obs)
3333

3434
def test_histogram_feature_group_by(self):
35+
# Call the function with a feature_name and group_by argument,
36+
# setting together=False to create separate plots for each group.
3537
fig, axs = histogram(
3638
self.adata,
3739
feature_name='marker1',
3840
group_by='obs2',
3941
together=False
4042
)
43+
44+
# Check that the function returned a list of Axes objects,
45+
# one for each group. In this case,
46+
# we expect there to be 2 groups, as obs2 has 2 unique values.
4147
self.assertEqual(len(axs), 2)
4248

49+
# Check that each object in axs is indeed an Axes object.
50+
self.assertIsInstance(axs[0], mpl.axes.Axes)
51+
self.assertIsInstance(axs[1], mpl.axes.Axes)
52+
4353
def test_both_feature_and_observation(self):
4454
err_msg = ("Cannot pass both feature_name and "
4555
"observation_name, choose one")

tests/test_visualization/test_tsne_plot.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
import anndata
33
import numpy as np
4+
import matplotlib.pyplot as plt
45
from spac.visualization import tsne_plot
56

67

@@ -9,6 +10,8 @@ class TestTsnePlot(unittest.TestCase):
910
def setUp(self):
1011
self.adata = anndata.AnnData(X=np.random.rand(10, 10))
1112
self.adata.obsm['X_tsne'] = np.random.rand(10, 2)
13+
self.adata.obs['color_column'] = np.random.choice(
14+
['A', 'B', 'C'], size=10)
1215

1316
def test_invalid_input_type(self):
1417
with self.assertRaises(ValueError) as cm:
@@ -24,6 +27,25 @@ def test_no_tsne_data(self):
2427
"adata.obsm does not contain 'X_tsne',"
2528
" perform t-SNE transformation first.")
2629

30+
def test_color_column(self):
31+
fig, ax = tsne_plot(self.adata, color_column='color_column')
32+
self.assertIsNotNone(fig)
33+
self.assertIsNotNone(ax)
34+
35+
def test_ax_provided(self):
36+
fig, ax_provided = plt.subplots()
37+
fig_returned, ax_returned = tsne_plot(self.adata, ax=ax_provided)
38+
self.assertIs(fig, fig_returned)
39+
self.assertIs(ax_provided, ax_returned)
40+
41+
def test_color_column_invalid(self):
42+
with self.assertRaises(KeyError) as cm:
43+
tsne_plot(self.adata, color_column='invalid_column')
44+
self.assertEqual(
45+
str(cm.exception),
46+
"\"'invalid_column' not found in adata.obs or adata.var.\""
47+
)
48+
2749

2850
if __name__ == '__main__':
2951
unittest.main()

0 commit comments

Comments
 (0)