Skip to content
50 changes: 36 additions & 14 deletions src/spac/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def visualize_2D_scatter(
x, y, labels=None, point_size=None, theme=None,
ax=None, annotate_centers=False,
x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None,
color_representation=None, **kwargs
color_representation=None, color_map=None, **kwargs
):
"""
Visualize 2D data using plt.scatter.
Expand Down Expand Up @@ -65,6 +65,8 @@ def visualize_2D_scatter(
Title for the plot.
color_representation : str, optional
Description of what the colors represent.
color_map : dictionary, optional
Dictionary containing colors for label annotations.
**kwargs
Additional keyword arguments passed to plt.scatter.

Expand All @@ -83,6 +85,8 @@ def visualize_2D_scatter(
raise ValueError("x and y must have the same length.")
if labels is not None and len(labels) != len(x):
raise ValueError("Labels length should match x and y length.")
if color_map is not None and not isinstance(color_map, dict):
raise ValueError("`color_map` must be a dict mapping label→color.")

# Define color themes
themes = {
Expand Down Expand Up @@ -136,20 +140,21 @@ def visualize_2D_scatter(
"Categorical."
)

# Combine colors from multiple colormaps
cmap1 = plt.get_cmap('tab20')
cmap2 = plt.get_cmap('tab20b')
cmap3 = plt.get_cmap('tab20c')
colors = cmap1.colors + cmap2.colors + cmap3.colors

# Use the number of unique clusters to set the colormap length
cmap = ListedColormap(colors[:len(unique_clusters)])
cluster_to_color = color_map if color_map is not None else {
str(cluster): colors[i % len(colors)]
for i, cluster in enumerate(unique_clusters)
}

for idx, cluster in enumerate(unique_clusters):
mask = np.array(labels) == cluster
color = cluster_to_color.get(str(cluster), 'gray')
ax.scatter(
x[mask], y[mask],
color=cmap(idx),
color=color,
label=cluster,
s=point_size
)
Expand Down Expand Up @@ -204,6 +209,7 @@ def dimensionality_reduction_plot(
layer=None,
ax=None,
associated_table=None,
color_map = None,
**kwargs):
"""
Visualize scatter plot in PCA, t-SNE, UMAP, or associated table.
Expand Down Expand Up @@ -232,6 +238,9 @@ def dimensionality_reduction_plot(
associated_table : str, optional (default: None)
Name of the key in `obsm` that contains the numpy array. Takes
precedence over `method`
color_map : str, optional (default: None)
Name of the key in adata.uns that contains color-mapping for
the plot
**kwargs
Parameters passed to visualize_2D_scatter function,
including point_size.
Expand All @@ -258,6 +267,14 @@ def dimensionality_reduction_plot(
if feature:
check_feature(adata, features=[feature])

color_mapping = None
if color_map is not None:
color_mapping = get_defined_color_map(
adata,
defined_color_map=color_map,
annotations=annotation
)

# Validate the method and check if the necessary data exists in adata.obsm
if associated_table is None:
valid_methods = ['tsne', 'umap', 'pca']
Expand Down Expand Up @@ -294,16 +311,20 @@ def dimensionality_reduction_plot(
x, y = adata.obsm[key].T

# Determine coloring scheme
if annotation:
if color_mapping is None:
if annotation:
color_values = adata.obs[annotation].astype('category').values
color_representation = annotation
elif feature:
data_src = adata.layers[layer] if layer else adata.X
color_values = data_src[:, adata.var_names == feature].squeeze()
color_representation = feature
else:
color_values = None
color_representation = None
else:
color_values = adata.obs[annotation].astype('category').values
color_representation = annotation
elif feature:
data_source = adata.layers[layer] if layer else adata.X
color_values = data_source[:, adata.var_names == feature].squeeze()
color_representation = feature
else:
color_values = None
color_representation = None

# Set axis titles based on method and color representation
if method == 'tsne':
Expand Down Expand Up @@ -338,6 +359,7 @@ def dimensionality_reduction_plot(
y_axis_title=y_axis_title,
plot_title=plot_title,
color_representation=color_representation,
color_map=color_mapping,
**kwargs
)

Expand Down