Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/spac/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def visualize_2D_scatter(
x, y, labels=None, point_size=None, theme=None,
ax=None, annotate_centers=False,
x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None,
color_representation=None, **kwargs
color_representation=None, color_map=None, **kwargs
):
"""
Visualize 2D data using plt.scatter.
Expand Down Expand Up @@ -65,6 +65,8 @@ def visualize_2D_scatter(
Title for the plot.
color_representation : str, optional
Description of what the colors represent.
color_map : dictionary, optional
Dictionary containing colors for label annotations.
**kwargs
Additional keyword arguments passed to plt.scatter.

Expand All @@ -83,6 +85,8 @@ def visualize_2D_scatter(
raise ValueError("x and y must have the same length.")
if labels is not None and len(labels) != len(x):
raise ValueError("Labels length should match x and y length.")
if color_map is not None and not isinstance(color_map, dict):
raise ValueError("`color_map` must be a dict mapping label→color.")

# Define color themes
themes = {
Expand Down Expand Up @@ -136,20 +140,21 @@ def visualize_2D_scatter(
"Categorical."
)

# Combine colors from multiple colormaps
cmap1 = plt.get_cmap('tab20')
cmap2 = plt.get_cmap('tab20b')
cmap3 = plt.get_cmap('tab20c')
colors = cmap1.colors + cmap2.colors + cmap3.colors

# Use the number of unique clusters to set the colormap length
cmap = ListedColormap(colors[:len(unique_clusters)])
cluster_to_color = color_map if color_map is not None else {
str(cluster): colors[i % len(colors)]
for i, cluster in enumerate(unique_clusters)
}

for idx, cluster in enumerate(unique_clusters):
mask = np.array(labels) == cluster
color = cluster_to_color.get(str(cluster), 'gray')
ax.scatter(
x[mask], y[mask],
color=cmap(idx),
color=color,
label=cluster,
s=point_size
)
Expand Down