From 2efc9c3190caa6f1d0efb3fe94df6ef379c75706 Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Thu, 17 Apr 2025 17:30:21 -0400 Subject: [PATCH 1/5] feat(auth): added pin_colors implementation to 2d_scatter function --- src/spac/visualization.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 0ab0ee11..98fc4544 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -35,7 +35,7 @@ def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, ax=None, annotate_centers=False, x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None, - color_representation=None, **kwargs + color_representation=None, defined_color_map=None, **kwargs ): """ Visualize 2D data using plt.scatter. @@ -65,6 +65,8 @@ def visualize_2D_scatter( Title for the plot. color_representation : str, optional Description of what the colors represent. + defined_color_map : dictionary, optional + Dictionary containing colors for label annotations. **kwargs Additional keyword arguments passed to plt.scatter. @@ -83,6 +85,10 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") + if defined_color_map is not None: + if not isinstance(defined_color_map, dict): + raise ValueError("`defined_color_map` must be a dict mapping label→color.") + color_dict = defined_color_map # Define color themes themes = { @@ -136,20 +142,25 @@ def visualize_2D_scatter( "Categorical." ) - # Combine colors from multiple colormaps - cmap1 = plt.get_cmap('tab20') - cmap2 = plt.get_cmap('tab20b') - cmap3 = plt.get_cmap('tab20c') - colors = cmap1.colors + cmap2.colors + cmap3.colors - - # Use the number of unique clusters to set the colormap length - cmap = ListedColormap(colors[:len(unique_clusters)]) + if defined_color_map is not None: + cluster_to_color = color_dict + else: + # fall back to your combined tab20 palettes + cmap1 = plt.get_cmap('tab20') + cmap2 = plt.get_cmap('tab20b') + cmap3 = plt.get_cmap('tab20c') + colors = cmap1.colors + cmap2.colors + cmap3.colors + cluster_to_color = { + str(cluster): colors[i % len(colors)] + for i, cluster in enumerate(unique_clusters) + } for idx, cluster in enumerate(unique_clusters): mask = np.array(labels) == cluster + color = cluster_to_color.get(str(cluster), 'gray') ax.scatter( x[mask], y[mask], - color=cmap(idx), + color=color, label=cluster, s=point_size ) From 013dd90354b789968e85d0ad4751a38d328f3e7e Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Thu, 17 Apr 2025 17:59:14 -0400 Subject: [PATCH 2/5] feat(auth): added pin_colors color map integration into dimensionality_reduction_plot --- src/spac/visualization.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 98fc4544..922c7f85 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -215,6 +215,7 @@ def dimensionality_reduction_plot( layer=None, ax=None, associated_table=None, + color_map = None, **kwargs): """ Visualize scatter plot in PCA, t-SNE, UMAP, or associated table. @@ -243,6 +244,9 @@ def dimensionality_reduction_plot( associated_table : str, optional (default: None) Name of the key in `obsm` that contains the numpy array. Takes precedence over `method` + color_map : str, optional (default: None) + Name of the key in adata.uns that contains color-mapping for + the plot **kwargs Parameters passed to visualize_2D_scatter function, including point_size. @@ -269,6 +273,14 @@ def dimensionality_reduction_plot( if feature: check_feature(adata, features=[feature]) + color_mapping = None + if color_map is not None: + color_mapping = get_defined_color_map( + adata, + defined_color_map=color_map, + annotations=annotation + ) + # Validate the method and check if the necessary data exists in adata.obsm if associated_table is None: valid_methods = ['tsne', 'umap', 'pca'] @@ -305,16 +317,20 @@ def dimensionality_reduction_plot( x, y = adata.obsm[key].T # Determine coloring scheme - if annotation: + if color_mapping is None: + if annotation: + color_values = adata.obs[annotation].astype('category').values + color_representation = annotation + elif feature: + data_src = adata.layers[layer] if layer else adata.X + color_values = data_src[:, adata.var_names == feature].squeeze() + color_representation = feature + else: + color_values = None + color_representation = None + else: color_values = adata.obs[annotation].astype('category').values color_representation = annotation - elif feature: - data_source = adata.layers[layer] if layer else adata.X - color_values = data_source[:, adata.var_names == feature].squeeze() - color_representation = feature - else: - color_values = None - color_representation = None # Set axis titles based on method and color representation if method == 'tsne': @@ -349,6 +365,7 @@ def dimensionality_reduction_plot( y_axis_title=y_axis_title, plot_title=plot_title, color_representation=color_representation, + defined_color_map=color_mapping, **kwargs ) From 8262ebc9c75edf5f64945f300e691f40f028287b Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Thu, 17 Apr 2025 18:01:13 -0400 Subject: [PATCH 3/5] style(header): changed parameter name to color_map instead of defined_color_map --- src/spac/visualization.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 98fc4544..b257beb7 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -35,7 +35,7 @@ def visualize_2D_scatter( x, y, labels=None, point_size=None, theme=None, ax=None, annotate_centers=False, x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None, - color_representation=None, defined_color_map=None, **kwargs + color_representation=None, color_map=None, **kwargs ): """ Visualize 2D data using plt.scatter. @@ -65,7 +65,7 @@ def visualize_2D_scatter( Title for the plot. color_representation : str, optional Description of what the colors represent. - defined_color_map : dictionary, optional + color_map : dictionary, optional Dictionary containing colors for label annotations. **kwargs Additional keyword arguments passed to plt.scatter. @@ -85,10 +85,10 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") - if defined_color_map is not None: - if not isinstance(defined_color_map, dict): - raise ValueError("`defined_color_map` must be a dict mapping label→color.") - color_dict = defined_color_map + if color_map is not None: + if not isinstance(color_map, dict): + raise ValueError("`color_map` must be a dict mapping label→color.") + color_dict = color_map # Define color themes themes = { @@ -142,7 +142,7 @@ def visualize_2D_scatter( "Categorical." ) - if defined_color_map is not None: + if color_map is not None: cluster_to_color = color_dict else: # fall back to your combined tab20 palettes From 57b7938a613e84280b503adc1c8a64f049187a7a Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Thu, 17 Apr 2025 18:14:27 -0400 Subject: [PATCH 4/5] style(function): changed function call to 2dscatter to apply changes of header keyword argument --- src/spac/visualization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index e87c889e..7c0783ea 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -365,7 +365,7 @@ def dimensionality_reduction_plot( y_axis_title=y_axis_title, plot_title=plot_title, color_representation=color_representation, - defined_color_map=color_mapping, + color_map=color_mapping, **kwargs ) From 6476e950910bce5cc487940712881540a838f3e9 Mon Sep 17 00:00:00 2001 From: Liza Shchehlik Date: Thu, 24 Apr 2025 16:59:34 -0400 Subject: [PATCH 5/5] refactor(func): Added changes made to scatterplot to this branch --- src/spac/visualization.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 7c0783ea..38513c39 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -85,10 +85,8 @@ def visualize_2D_scatter( raise ValueError("x and y must have the same length.") if labels is not None and len(labels) != len(x): raise ValueError("Labels length should match x and y length.") - if color_map is not None: - if not isinstance(color_map, dict): - raise ValueError("`color_map` must be a dict mapping label→color.") - color_dict = color_map + if color_map is not None and not isinstance(color_map, dict): + raise ValueError("`color_map` must be a dict mapping label→color.") # Define color themes themes = { @@ -142,18 +140,14 @@ def visualize_2D_scatter( "Categorical." ) - if color_map is not None: - cluster_to_color = color_dict - else: - # fall back to your combined tab20 palettes - cmap1 = plt.get_cmap('tab20') - cmap2 = plt.get_cmap('tab20b') - cmap3 = plt.get_cmap('tab20c') - colors = cmap1.colors + cmap2.colors + cmap3.colors - cluster_to_color = { - str(cluster): colors[i % len(colors)] - for i, cluster in enumerate(unique_clusters) - } + cmap1 = plt.get_cmap('tab20') + cmap2 = plt.get_cmap('tab20b') + cmap3 = plt.get_cmap('tab20c') + colors = cmap1.colors + cmap2.colors + cmap3.colors + cluster_to_color = color_map if color_map is not None else { + str(cluster): colors[i % len(colors)] + for i, cluster in enumerate(unique_clusters) + } for idx, cluster in enumerate(unique_clusters): mask = np.array(labels) == cluster