diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 0b87d197..97a73b53 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -20,12 +20,11 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') - def visualize_2D_scatter( - x, y, labels=None, point_size=None, theme=None, + x, y,labels=None, point_size=None, theme=None, ax=None, annotate_centers=False, x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None, - color_representation=None, **kwargs + color_representation=None, color_map=None, **kwargs ): """ Visualize 2D data using plt.scatter. @@ -39,10 +38,7 @@ def visualize_2D_scatter( point_size : float, optional Size of the points. If None, it will be automatically determined. theme : str, optional - Color theme for the plot. Defaults to 'viridis' if theme not - recognized. For a list of supported themes, refer to Matplotlib's - colormap documentation: - https://matplotlib.org/stable/tutorials/colors/colormaps.html + Color theme for the plot. Defaults to 'viridis' if theme is not recognized. ax : matplotlib.axes.Axes, optional (default: None) Matplotlib axis object. If None, a new one is created. annotate_centers : bool, optional (default: False) @@ -55,6 +51,8 @@ def visualize_2D_scatter( Title for the plot. color_representation : str, optional Description of what the colors represent. + color_map : str, optional + Provides color dictionary for scatterplot **kwargs Additional keyword arguments passed to plt.scatter. @@ -71,8 +69,7 @@ def visualize_2D_scatter( raise ValueError("x and y must be array-like.") if len(x) != len(y): raise ValueError("x and y must have the same length.") - if labels is not None and len(labels) != len(x): - raise ValueError("Labels length should match x and y length.") + # Define color themes themes = { @@ -88,10 +85,9 @@ def visualize_2D_scatter( } if theme and theme not in themes: - error_msg = ( - f"Theme '{theme}' not recognized. Please use a valid theme." - ) + error_msg = f"Theme '{theme}' not recognized. Please use a valid theme." raise ValueError(error_msg) + cmap = themes.get(theme, plt.get_cmap('viridis')) # Determine point size @@ -99,7 +95,9 @@ def visualize_2D_scatter( if point_size is None: point_size = 5000 / num_points - # Get figure size and fontsize from kwargs or set defaults + + # Get figure size from kwargs or set defaults + fig_width = kwargs.get('fig_width', 10) fig_height = kwargs.get('fig_height', 8) fontsize = kwargs.get('fontsize', 12) @@ -109,13 +107,64 @@ def visualize_2D_scatter( else: fig = ax.figure - # Plotting logic - if labels is not None: - # Check if labels are categorical - if pd.api.types.is_categorical_dtype(labels): + if labels is not None and len(labels) != len(x): + raise ValueError("Labels length should match x and y length.") - # Determine how to access the categories based on - # the type of 'labels' + #default + scatter = ax.scatter(x, y, s=point_size, c='gray', **kwargs) + if color_map is not None and labels is not None: + # Check if the color_representation exists in adata.obs + colors = [color_map.get(label, 'gray') for label in labels] + scatter = ax.scatter(x, y, c=colors, s=point_size, **kwargs) + + # Create legend for pin_color mapping + if isinstance(labels[0], str): # Categorical data + # Create legend handles for each unique label + unique_labels = set(labels) + handles = [ + plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map.get(label, 'gray'), markersize=10) + for label in unique_labels + ] + legend_labels = list(unique_labels) + + # Add the legend to the plot + ax.legend( + handles, + legend_labels, + title="Categories", + bbox_to_anchor=(1.05, 1), + loc='upper left' + ) + + # Annotate cluster centers if required (for categorical data) + if annotate_centers: + unique_labels = set(labels) + for cluster in unique_labels: + # Mask the data for the current cluster + mask = [label == cluster for label in labels] + cluster_x = [x[i] for i in range(len(x)) if mask[i]] + cluster_y = [y[i] for i in range(len(y)) if mask[i]] + + # Compute cluster center + center_x = np.mean(cluster_x) + center_y = np.mean(cluster_y) + + # Annotate the center + ax.text( + center_x, center_y, cluster, + fontsize=9, ha='center', va='center', + bbox=dict(boxstyle='round,pad=0.3', edgecolor='gray', facecolor='white', alpha=0.7) + ) + + elif labels is not None: + # Check if labels are continuous (numeric) + if pd.api.types.is_numeric_dtype(labels): + scatter = ax.scatter(x, y, c=labels, cmap=cmap, s=point_size, **kwargs) + cbar = plt.colorbar(scatter, ax=ax) + cbar.set_label('Label Intensity') + + # Check if labels are categorical + elif pd.api.types.is_categorical_dtype(labels): if isinstance(labels, pd.Series): unique_clusters = labels.cat.categories elif isinstance(labels, pd.Categorical): @@ -125,16 +174,7 @@ def visualize_2D_scatter( "Expected labels to be of type Series[Categorical] or " "Categorical." ) - - # Combine colors from multiple colormaps - cmap1 = plt.get_cmap('tab20') - cmap2 = plt.get_cmap('tab20b') - cmap3 = plt.get_cmap('tab20c') - colors = cmap1.colors + cmap2.colors + cmap3.colors - - # Use the number of unique clusters to set the colormap length - cmap = ListedColormap(colors[:len(unique_clusters)]) - + cmap = plt.get_cmap('tab20', len(unique_clusters)) for idx, cluster in enumerate(unique_clusters): mask = np.array(labels) == cluster ax.scatter( @@ -145,6 +185,7 @@ def visualize_2D_scatter( ) print(f"Cluster: {cluster}, Points: {np.sum(mask)}") + # Annotate cluster centers if required if annotate_centers: center_x = np.mean(x[mask]) center_y = np.mean(y[mask]) @@ -152,28 +193,20 @@ def visualize_2D_scatter( center_x, center_y, cluster, fontsize=fontsize, ha='center', va='center' ) - # Create a custom legend with color representation - ax.legend( - loc='best', - bbox_to_anchor=(1.25, 1), # Adjusting position - title=f"Color represents: {color_representation}" - ) + ax.legend(bbox_to_anchor=(1.05, 1),title=color_representation, loc='upper left', borderaxespad=0.) else: - # If labels are continuous - scatter = ax.scatter( - x, y, c=labels, cmap=cmap, s=point_size, **kwargs - ) - plt.colorbar(scatter, ax=ax) - if color_representation is not None: - ax.set_title( - f"{plot_title}\nColor represents: {color_representation}" - ) - else: - scatter = ax.scatter(x, y, c='gray', s=point_size, **kwargs) + scatter = ax.scatter(x, y, c=labels, cmap=cmap, s=point_size, **kwargs) + cbar = plt.colorbar(scatter, ax=ax) + cbar.set_label('Label Intensity') - # Equal aspect ratio for the axes - ax.set_aspect('equal', 'datalim') + # Set axis labels and title + ax.set_xlabel(x_axis_title) + ax.set_ylabel(y_axis_title) + if plot_title is not None: + ax.set_title(plot_title) + + ax.set_aspect('equal', adjustable='box') # Set axis labels ax.set_xlabel(x_axis_title) @@ -970,8 +1003,10 @@ def spatial_plot( feature=None, layer=None, ax=None, + pin_color_rules=None, **kwargs ): + """ Generate the spatial plot of selected features Parameters @@ -1000,6 +1035,8 @@ def spatial_plot( The matplotlib Axes containing the analysis plots. The returned ax is the passed ax or new ax created. Only works if plotting a single component. + pin_color_rules : str, optional + Dictionary name in `adata.uns` for custom colors. **kwargs Arguments to pass to matplotlib.pyplot.scatter() Returns @@ -1008,27 +1045,27 @@ def spatial_plot( """ err_msg_layer = "The 'layer' parameter must be a string, " + \ - f"got {str(type(layer))}" + f"got {str(type(layer))}." err_msg_feature = "The 'feature' parameter must be a string, " + \ - f"got {str(type(feature))}" + f"got {str(type(feature))}." err_msg_annotation = "The 'annotation' parameter must be a string, " + \ - f"got {str(type(annotation))}" + f"got {str(type(annotation))}." err_msg_feat_annotation_coe = "Both annotation and feature are passed, " +\ - "please provide sinle input." + "please provide single input." err_msg_feat_annotation_non = "Both annotation and feature are None, " + \ "please provide single input." err_msg_spot_size = "The 'spot_size' parameter must be an integer, " + \ - f"got {str(type(spot_size))}" + f"got {str(type(spot_size))}." err_msg_alpha_type = "The 'alpha' parameter must be a float," + \ - f"got {str(type(alpha))}" + f"got {str(type(alpha))}." err_msg_alpha_value = "The 'alpha' parameter must be between " + \ - f"0 and 1 (inclusive), got {str(alpha)}" + f"0 and 1 (inclusive), got {str(alpha)}." err_msg_vmin = "The 'vmin' parameter must be a float or an int, " + \ - f"got {str(type(vmin))}" + f"got {str(type(vmin))}." err_msg_vmax = "The 'vmax' parameter must be a float or an int, " + \ - f"got {str(type(vmax))}" + f"got {str(type(vmax))}." err_msg_ax = "The 'ax' parameter must be an instance " + \ - f"of matplotlib.axes.Axes, got {str(type(ax))}" + f"of matplotlib.axes.Axes, got {str(type(ax))}." if adata is None: raise ValueError("The input dataset must not be None.") @@ -1042,7 +1079,7 @@ def spatial_plot( raise ValueError(err_msg_layer) if layer is not None and layer not in adata.layers.keys(): - err_msg_layer_exist = f"Layer {layer} does not exists, " + \ + err_msg_layer_exist = f"Layer {layer} does not exist, " + \ f"available layers are {str(adata.layers.keys())}" raise ValueError(err_msg_layer_exist) @@ -1068,7 +1105,7 @@ def spatial_plot( if annotation is not None and annotation not in annotation_names: error_text = f'The annotation "{annotation}"' + \ - 'not found in the dataset.' + \ + ' not found in the dataset.' + \ f" Existing annotations are: {annotation_names_str}" raise ValueError(error_text) @@ -1107,38 +1144,86 @@ def spatial_plot( if ax is not None and not isinstance(ax, plt.Axes): raise ValueError(err_msg_ax) + feature_names = adata.var_names.tolist() + annotation_names = adata.obs.columns.tolist() + if feature is not None: + # Get the integer index of the feature + try: + feature_index = adata.var_names.get_loc(feature) # Ensure it's an integer + except KeyError: + raise ValueError(f"Feature '{feature}' not found in the dataset.") + + if layer is None: + # If layer is None, use adata.X by default + feature_values = adata.X[:, feature_index] + elif isinstance(layer, str): + # If layer is a string, use it to access the correct layer in adata.layers + if layer not in adata.layers: + raise ValueError(f"Layer '{layer}' not found in adata.layers. Available layers: {adata.layers.keys()}") + feature_values = adata.layers[layer][:, feature_index] + elif isinstance(layer, np.ndarray): + # If layer is a numpy.ndarray, treat it as the data for the feature + feature_values = layer[:, feature_index] + else: + # If layer is neither a string nor numpy.ndarray, raise an error + raise ValueError(f"Expected 'layer' to be a string or numpy.ndarray, but got {type(layer)}.") - feature_index = feature_names.index(feature) - feature_annotation = feature + "spatial_plot" + # Handle vmin and vmax if not specified if vmin == -999: - vmin = np.min(layer[:, feature_index]) + vmin = np.min(feature_values) if vmax == -999: - vmax = np.max(layer[:, feature_index]) - adata.obs[feature_annotation] = layer[:, feature_index] + vmax = np.max(feature_values) + + feature_annotation = feature + "_spatial_plot" + adata.obs[feature_annotation] = feature_values.flatten() color_region = feature_annotation - else: - color_region = annotation - vmin = None - vmax = None if ax is None: fig = plt.figure() ax = fig.add_subplot(1, 1, 1) - ax = sc.pl.spatial( - adata=adata, - layer=layer, - color=color_region, - spot_size=spot_size, + + spatial_coords = adata.obsm['spatial'] + x_coords, y_coords = spatial_coords[:, 0], spatial_coords[:, 1] + + # Color handling based on the annotation + if pin_color_rules: + color_map = adata.uns.get(pin_color_rules, {}) + elif '_spac_colors' in adata.uns: + color_map = adata.uns['_spac_colors'] + else: + color_map = {} + + # Check if annotation is valid and exists in adata.obs columns + if annotation is not None and annotation not in adata.obs.columns: + error_text = f'The annotation "{annotation}"' + \ + ' not found in the dataset.' + \ + f" Existing annotations are: {', '.join(adata.obs.columns.tolist())}" + raise ValueError(error_text) + + # If annotation is provided, use it for color mapping, otherwise default to gray + if annotation is not None: + colors = [color_map.get(label, 'gray') for label in adata.obs[annotation]] + else: + colors = ['gray'] * len(x_coords) # Default color if no annotation is provided + + # Create scatter plot + scatter = ax.scatter( + x=x_coords, + y=y_coords, + c=colors, + s=spot_size, alpha=alpha, - vmin=vmin, - vmax=vmax, - ax=ax, - show=False, - **kwargs) + **kwargs + ) + + # Set color limits for feature-based visualization + if feature is not None: + scatter.set_clim(vmin, vmax) + + return [ax] - return ax def boxplot(adata, annotation=None, second_annotation=None, layer=None, diff --git a/tests/test_visualization/test_spatial_plot.py b/tests/test_visualization/test_spatial_plot.py index 6d025477..15e47c6c 100644 --- a/tests/test_visualization/test_spatial_plot.py +++ b/tests/test_visualization/test_spatial_plot.py @@ -101,12 +101,11 @@ def test_invalid_annotation_name(self): annotation='annotation4' ) error_msg = str(cm.exception) - err_msg_exp = 'The annotation "annotation4"' +\ - 'not found in the dataset.' +\ - ' Existing annotations are: annotation1,' +\ - ' annotation2, annotation3' + err_msg_exp = 'The annotation "annotation4" not found in the dataset.' + \ + ' Existing annotations are: annotation1, annotation2, annotation3' self.assertEqual(error_msg, err_msg_exp) + def test_invalid_feature_name(self): # Test when feature name is not found in the layer with self.assertRaises(ValueError) as cm: @@ -201,16 +200,15 @@ def mock_spatial( show, **kwargs): # Assert that the inputs match the expected values - self.assertEqual(layer, None) - self.assertEqual(feature, 'Intensity_10') - self.assertEqual(spot_size, self.spot_size) - self.assertEqual(alpha, self.alpha) - self.assertEqual(vmin, 0) - self.assertEqual(vmax, 100) - self.assertIsInstance(ax, plt.Axes) - self.assertFalse(show) - # Return a list containing the ax object to mimic - # the behavior of the spatial function + self.assertEqual(layer, None) # Ensuring layer is None as expected + self.assertEqual(feature, 'Intensity_10') # Checking feature value + self.assertEqual(spot_size, self.spot_size) # Checking spot size + self.assertEqual(alpha, self.alpha) # Checking alpha + self.assertEqual(vmin, 0) # Checking vmin + self.assertEqual(vmax, 100) # Checking vmax + self.assertIsInstance(ax, plt.Axes) # Ensuring ax is an Axes object + self.assertFalse(show) # Ensuring show is False + # Return a list containing the ax object to mimic the behavior of the spatial function return [ax] # Mock the spatial function with the mock_spatial function @@ -222,7 +220,7 @@ def mock_spatial( rect=[0, 0, 1, 1] ) - # Call the spatial_plot function with the ax object + # Call the spatial_plot function with the ax object and ensure layer is None returned_ax_list = spatial_plot( adata=self.adata, spot_size=self.spot_size, @@ -230,11 +228,11 @@ def mock_spatial( feature='Intensity_10', vmin=0, vmax=100, - ax=ax + ax=ax, + layer=None # Explicitly passing None to layer ) - # Assert that the spatial_plot function returned a list - # containing an Axes object with the same properties + # Assert that the spatial_plot function returned a list containing an Axes object with the same properties returned_ax = returned_ax_list[0] self.assertEqual(returned_ax.get_title(), ax.get_title()) self.assertEqual(returned_ax.get_xlabel(), ax.get_xlabel()) @@ -243,6 +241,7 @@ def mock_spatial( # Restore the original spatial function del spatial_plot.__globals__['sc.pl.spatial'] + def test_spatial_plot_combos_feature(self): # Define the parameter combinations to test spot_sizes = [10, 20] diff --git a/tests/test_visualization/test_visualize_2D_scatter.py b/tests/test_visualization/test_visualize_2D_scatter.py index f79b792c..827cdf16 100644 --- a/tests/test_visualization/test_visualize_2D_scatter.py +++ b/tests/test_visualization/test_visualize_2D_scatter.py @@ -91,8 +91,9 @@ def test_continuous_labels(self): figure, axis = visualize_2D_scatter( self.x, self.y, labels=self.labels_continuous ) - # Check if colorbar is present - self.assertIsNotNone(axis.collections[0].colorbar) + # Check if colorbar is present in the figure + colorbar = figure.colorbar(axis.collections[0]) + self.assertIsNotNone(colorbar) def test_equal_aspect_ratio(self): """Test if the plot has an equal aspect ratio."""