99from matplotlib .colors import ListedColormap , BoundaryNorm
1010
1111
12- def tsne_plot (adata , ax = None , ** kwargs ):
12+ def tsne_plot (adata , color_column = None , ax = None , ** kwargs ):
1313 """
1414 Visualize scatter plot in tSNE basis.
1515
@@ -18,6 +18,8 @@ def tsne_plot(adata, ax=None, **kwargs):
1818 adata : anndata.AnnData
1919 The AnnData object with t-SNE coordinates precomputed by the 'tsne'
2020 function and stored in 'adata.obsm["X_tsne"]'.
21+ color_column : str, optional
22+ The name of the column to use for coloring the scatter plot points.
2123 ax : matplotlib.axes.Axes, optional (default: None)
2224 A matplotlib axes object to plot on.
2325 If not provided, a new figure and axes will be created.
@@ -42,6 +44,17 @@ def tsne_plot(adata, ax=None, **kwargs):
4244 # Create a new figure and axes if not provided
4345 if ax is None :
4446 fig , ax = plt .subplots ()
47+ else :
48+ fig = ax .get_figure ()
49+
50+ if color_column and (color_column not in adata .obs .columns and
51+ color_column not in adata .var .columns ):
52+ err_msg = f"'{ color_column } ' not found in adata.obs or adata.var."
53+ raise KeyError (err_msg )
54+
55+ # Add color column to the kwargs for the scanpy plot
56+ if color_column :
57+ kwargs ['color' ] = color_column
4558
4659 # Plot the t-SNE
4760 sc .pl .tsne (adata , ax = ax , ** kwargs )
@@ -133,10 +146,10 @@ def histogram(adata, feature_name=None, observation_name=None, layer=None,
133146 fig , axs = plt .subplots (n_groups , 1 , figsize = (5 , 5 * n_groups ))
134147 if n_groups == 1 :
135148 axs = [axs ]
136- for i , ax in enumerate (axs ):
149+ for i , ax_i in enumerate (axs ):
137150 sns .histplot (data = df [df [group_by ] == groups [i ]].dropna (),
138- x = x , ax = ax , ** kwargs )
139- ax .set_title (groups [i ])
151+ x = x , ax = ax_i , ** kwargs )
152+ ax_i .set_title (groups [i ])
140153 return fig , axs
141154
142155 sns .histplot (data = df , x = x , ax = ax , ** kwargs )
@@ -556,7 +569,7 @@ def spatial_plot(
556569 raise ValueError (err_msg_ax )
557570
558571 if feature is not None :
559-
572+
560573 feature_index = feature_names .index (feature )
561574 feature_obs = feature + "spatial_plot"
562575 if vmin == - 999 :
0 commit comments