diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 0ab0ee11..38513c39 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -35,7 +35,7 @@ def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, ax=None, annotate_centers=False, x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None, - color_representation=None, **kwargs + color_representation=None, color_map=None, **kwargs ): """ Visualize 2D data using plt.scatter. @@ -65,6 +65,8 @@ def visualize_2D_scatter( Title for the plot. color_representation : str, optional Description of what the colors represent. + color_map : dictionary, optional + Dictionary containing colors for label annotations. **kwargs Additional keyword arguments passed to plt.scatter. @@ -83,6 +85,8 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") + if color_map is not None and not isinstance(color_map, dict): + raise ValueError("`color_map` must be a dict mapping label→color.") # Define color themes themes = { @@ -136,20 +140,21 @@ def visualize_2D_scatter( "Categorical." ) - # Combine colors from multiple colormaps cmap1 = plt.get_cmap('tab20') cmap2 = plt.get_cmap('tab20b') cmap3 = plt.get_cmap('tab20c') colors = cmap1.colors + cmap2.colors + cmap3.colors - - # Use the number of unique clusters to set the colormap length - cmap = ListedColormap(colors[:len(unique_clusters)]) + cluster_to_color = color_map if color_map is not None else { + str(cluster): colors[i % len(colors)] + for i, cluster in enumerate(unique_clusters) + } for idx, cluster in enumerate(unique_clusters): mask = np.array(labels) == cluster + color = cluster_to_color.get(str(cluster), 'gray') ax.scatter( x[mask], y[mask], - color=cmap(idx), + color=color, label=cluster, s=point_size ) @@ -204,6 +209,7 @@ def dimensionality_reduction_plot( layer=None, ax=None, associated_table=None, + color_map = None, **kwargs): """ Visualize scatter plot in PCA, t-SNE, UMAP, or associated table. @@ -232,6 +238,9 @@ def dimensionality_reduction_plot( associated_table : str, optional (default: None) Name of the key in `obsm` that contains the numpy array. Takes precedence over `method` + color_map : str, optional (default: None) + Name of the key in adata.uns that contains color-mapping for + the plot **kwargs Parameters passed to visualize_2D_scatter function, including point_size. @@ -258,6 +267,14 @@ def dimensionality_reduction_plot( if feature: check_feature(adata, features=[feature]) + color_mapping = None + if color_map is not None: + color_mapping = get_defined_color_map( + adata, + defined_color_map=color_map, + annotations=annotation + ) + # Validate the method and check if the necessary data exists in adata.obsm if associated_table is None: valid_methods = ['tsne', 'umap', 'pca'] @@ -294,16 +311,20 @@ def dimensionality_reduction_plot( x, y = adata.obsm[key].T # Determine coloring scheme - if annotation: + if color_mapping is None: + if annotation: + color_values = adata.obs[annotation].astype('category').values + color_representation = annotation + elif feature: + data_src = adata.layers[layer] if layer else adata.X + color_values = data_src[:, adata.var_names == feature].squeeze() + color_representation = feature + else: + color_values = None + color_representation = None + else: color_values = adata.obs[annotation].astype('category').values color_representation = annotation - elif feature: - data_source = adata.layers[layer] if layer else adata.X - color_values = data_source[:, adata.var_names == feature].squeeze() - color_representation = feature - else: - color_values = None - color_representation = None # Set axis titles based on method and color representation if method == 'tsne': @@ -338,6 +359,7 @@ def dimensionality_reduction_plot( y_axis_title=y_axis_title, plot_title=plot_title, color_representation=color_representation, + color_map=color_mapping, **kwargs )