Skip to content
Open
245 changes: 165 additions & 80 deletions src/spac/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s')


def visualize_2D_scatter(
x, y, labels=None, point_size=None, theme=None,
x, y,labels=None, point_size=None, theme=None,
ax=None, annotate_centers=False,
x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None,
color_representation=None, **kwargs
color_representation=None, color_map=None, **kwargs
):
"""
Visualize 2D data using plt.scatter.
Expand All @@ -39,10 +38,7 @@ def visualize_2D_scatter(
point_size : float, optional
Size of the points. If None, it will be automatically determined.
theme : str, optional
Color theme for the plot. Defaults to 'viridis' if theme not
recognized. For a list of supported themes, refer to Matplotlib's
colormap documentation:
https://matplotlib.org/stable/tutorials/colors/colormaps.html
Color theme for the plot. Defaults to 'viridis' if no theme is given; an unrecognized theme raises a ValueError.
ax : matplotlib.axes.Axes, optional (default: None)
Matplotlib axis object. If None, a new one is created.
annotate_centers : bool, optional (default: False)
Expand All @@ -55,6 +51,8 @@ def visualize_2D_scatter(
Title for the plot.
color_representation : str, optional
Description of what the colors represent.
color_map : dict, optional
Dictionary mapping each label to a color for the scatter plot. Labels missing from the dictionary are drawn in gray.
**kwargs
Additional keyword arguments passed to plt.scatter.

Expand All @@ -71,8 +69,7 @@ def visualize_2D_scatter(
raise ValueError("x and y must be array-like.")
if len(x) != len(y):
raise ValueError("x and y must have the same length.")
if labels is not None and len(labels) != len(x):
raise ValueError("Labels length should match x and y length.")


# Define color themes
themes = {
Expand All @@ -88,18 +85,19 @@ def visualize_2D_scatter(
}

if theme and theme not in themes:
error_msg = (
f"Theme '{theme}' not recognized. Please use a valid theme."
)
error_msg = f"Theme '{theme}' not recognized. Please use a valid theme."
raise ValueError(error_msg)

cmap = themes.get(theme, plt.get_cmap('viridis'))

# Determine point size
num_points = len(x)
if point_size is None:
point_size = 5000 / num_points

# Get figure size and fontsize from kwargs or set defaults

# Get figure size from kwargs or set defaults

fig_width = kwargs.get('fig_width', 10)
fig_height = kwargs.get('fig_height', 8)
fontsize = kwargs.get('fontsize', 12)
Expand All @@ -109,13 +107,64 @@ def visualize_2D_scatter(
else:
fig = ax.figure

# Plotting logic
if labels is not None:
# Check if labels are categorical
if pd.api.types.is_categorical_dtype(labels):
if labels is not None and len(labels) != len(x):
raise ValueError("Labels length should match x and y length.")

# Determine how to access the categories based on
# the type of 'labels'
#default
scatter = ax.scatter(x, y, s=point_size, c='gray', **kwargs)
if color_map is not None and labels is not None:
# Look up each label's color in color_map; labels without an entry fall back to gray
colors = [color_map.get(label, 'gray') for label in labels]
scatter = ax.scatter(x, y, c=colors, s=point_size, **kwargs)

# Create legend for pin_color mapping
if isinstance(labels[0], str): # Categorical data
# Create legend handles for each unique label
unique_labels = set(labels)
handles = [
plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map.get(label, 'gray'), markersize=10)
for label in unique_labels
]
legend_labels = list(unique_labels)

# Add the legend to the plot
ax.legend(
handles,
legend_labels,
title="Categories",
bbox_to_anchor=(1.05, 1),
loc='upper left'
)

# Annotate cluster centers if required (for categorical data)
if annotate_centers:
unique_labels = set(labels)
for cluster in unique_labels:
# Mask the data for the current cluster
mask = [label == cluster for label in labels]
cluster_x = [x[i] for i in range(len(x)) if mask[i]]
cluster_y = [y[i] for i in range(len(y)) if mask[i]]

# Compute cluster center
center_x = np.mean(cluster_x)
center_y = np.mean(cluster_y)

# Annotate the center
ax.text(
center_x, center_y, cluster,
fontsize=9, ha='center', va='center',
bbox=dict(boxstyle='round,pad=0.3', edgecolor='gray', facecolor='white', alpha=0.7)
)

elif labels is not None:
# Check if labels are continuous (numeric)
if pd.api.types.is_numeric_dtype(labels):
scatter = ax.scatter(x, y, c=labels, cmap=cmap, s=point_size, **kwargs)
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('Label Intensity')

# Check if labels are categorical
elif pd.api.types.is_categorical_dtype(labels):
if isinstance(labels, pd.Series):
unique_clusters = labels.cat.categories
elif isinstance(labels, pd.Categorical):
Expand All @@ -125,16 +174,7 @@ def visualize_2D_scatter(
"Expected labels to be of type Series[Categorical] or "
"Categorical."
)

# Combine colors from multiple colormaps
cmap1 = plt.get_cmap('tab20')
cmap2 = plt.get_cmap('tab20b')
cmap3 = plt.get_cmap('tab20c')
colors = cmap1.colors + cmap2.colors + cmap3.colors

# Use the number of unique clusters to set the colormap length
cmap = ListedColormap(colors[:len(unique_clusters)])

cmap = plt.get_cmap('tab20', len(unique_clusters))
for idx, cluster in enumerate(unique_clusters):
mask = np.array(labels) == cluster
ax.scatter(
Expand All @@ -145,35 +185,28 @@ def visualize_2D_scatter(
)
print(f"Cluster: {cluster}, Points: {np.sum(mask)}")

# Annotate cluster centers if required
if annotate_centers:
center_x = np.mean(x[mask])
center_y = np.mean(y[mask])
ax.text(
center_x, center_y, cluster,
fontsize=fontsize, ha='center', va='center'
)
# Create a custom legend with color representation
ax.legend(
loc='best',
bbox_to_anchor=(1.25, 1), # Adjusting position
title=f"Color represents: {color_representation}"
)

ax.legend(bbox_to_anchor=(1.05, 1),title=color_representation, loc='upper left', borderaxespad=0.)
else:
# If labels are continuous
scatter = ax.scatter(
x, y, c=labels, cmap=cmap, s=point_size, **kwargs
)
plt.colorbar(scatter, ax=ax)
if color_representation is not None:
ax.set_title(
f"{plot_title}\nColor represents: {color_representation}"
)
else:
scatter = ax.scatter(x, y, c='gray', s=point_size, **kwargs)
scatter = ax.scatter(x, y, c=labels, cmap=cmap, s=point_size, **kwargs)
cbar = plt.colorbar(scatter, ax=ax)
cbar.set_label('Label Intensity')

# Equal aspect ratio for the axes
ax.set_aspect('equal', 'datalim')
# Set axis labels and title
ax.set_xlabel(x_axis_title)
ax.set_ylabel(y_axis_title)
if plot_title is not None:
ax.set_title(plot_title)

ax.set_aspect('equal', adjustable='box')

# Set axis labels
ax.set_xlabel(x_axis_title)
Expand Down Expand Up @@ -970,8 +1003,10 @@ def spatial_plot(
feature=None,
layer=None,
ax=None,
pin_color_rules=None,
**kwargs
):

"""
Generate the spatial plot of selected features
Parameters
Expand Down Expand Up @@ -1000,6 +1035,8 @@ def spatial_plot(
The matplotlib Axes containing the analysis plots.
The returned ax is the passed ax or new ax created.
Only works if plotting a single component.
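pin_color_rules : str, optional
Key in `adata.uns` whose value is a dictionary mapping annotation labels to colors. If not given, `adata.uns['_spac_colors']` is used when present; labels without a mapping are drawn in gray.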
**kwargs
Arguments to pass to matplotlib.pyplot.scatter()
Returns
Expand All @@ -1008,27 +1045,27 @@ def spatial_plot(
"""

err_msg_layer = "The 'layer' parameter must be a string, " + \
f"got {str(type(layer))}"
f"got {str(type(layer))}."
err_msg_feature = "The 'feature' parameter must be a string, " + \
f"got {str(type(feature))}"
f"got {str(type(feature))}."
err_msg_annotation = "The 'annotation' parameter must be a string, " + \
f"got {str(type(annotation))}"
f"got {str(type(annotation))}."
err_msg_feat_annotation_coe = "Both annotation and feature are passed, " +\
"please provide sinle input."
"please provide single input."
err_msg_feat_annotation_non = "Both annotation and feature are None, " + \
"please provide single input."
err_msg_spot_size = "The 'spot_size' parameter must be an integer, " + \
f"got {str(type(spot_size))}"
f"got {str(type(spot_size))}."
err_msg_alpha_type = "The 'alpha' parameter must be a float," + \
f"got {str(type(alpha))}"
f"got {str(type(alpha))}."
err_msg_alpha_value = "The 'alpha' parameter must be between " + \
f"0 and 1 (inclusive), got {str(alpha)}"
f"0 and 1 (inclusive), got {str(alpha)}."
err_msg_vmin = "The 'vmin' parameter must be a float or an int, " + \
f"got {str(type(vmin))}"
f"got {str(type(vmin))}."
err_msg_vmax = "The 'vmax' parameter must be a float or an int, " + \
f"got {str(type(vmax))}"
f"got {str(type(vmax))}."
err_msg_ax = "The 'ax' parameter must be an instance " + \
f"of matplotlib.axes.Axes, got {str(type(ax))}"
f"of matplotlib.axes.Axes, got {str(type(ax))}."

if adata is None:
raise ValueError("The input dataset must not be None.")
Expand All @@ -1042,7 +1079,7 @@ def spatial_plot(
raise ValueError(err_msg_layer)

if layer is not None and layer not in adata.layers.keys():
err_msg_layer_exist = f"Layer {layer} does not exists, " + \
err_msg_layer_exist = f"Layer {layer} does not exist, " + \
f"available layers are {str(adata.layers.keys())}"
raise ValueError(err_msg_layer_exist)

Expand All @@ -1068,7 +1105,7 @@ def spatial_plot(

if annotation is not None and annotation not in annotation_names:
error_text = f'The annotation "{annotation}"' + \
'not found in the dataset.' + \
' not found in the dataset.' + \
f" Existing annotations are: {annotation_names_str}"
raise ValueError(error_text)

Expand Down Expand Up @@ -1107,38 +1144,86 @@ def spatial_plot(
if ax is not None and not isinstance(ax, plt.Axes):
raise ValueError(err_msg_ax)

feature_names = adata.var_names.tolist()
annotation_names = adata.obs.columns.tolist()

if feature is not None:
# Get the integer index of the feature
try:
feature_index = adata.var_names.get_loc(feature) # Ensure it's an integer
except KeyError:
raise ValueError(f"Feature '{feature}' not found in the dataset.")

if layer is None:
# If layer is None, use adata.X by default
feature_values = adata.X[:, feature_index]
elif isinstance(layer, str):
# If layer is a string, use it to access the correct layer in adata.layers
if layer not in adata.layers:
raise ValueError(f"Layer '{layer}' not found in adata.layers. Available layers: {adata.layers.keys()}")
feature_values = adata.layers[layer][:, feature_index]
elif isinstance(layer, np.ndarray):
# If layer is a numpy.ndarray, treat it as the data for the feature
feature_values = layer[:, feature_index]
else:
# If layer is neither a string nor numpy.ndarray, raise an error
raise ValueError(f"Expected 'layer' to be a string or numpy.ndarray, but got {type(layer)}.")

feature_index = feature_names.index(feature)
feature_annotation = feature + "spatial_plot"
# Handle vmin and vmax if not specified
if vmin == -999:
vmin = np.min(layer[:, feature_index])
vmin = np.min(feature_values)
if vmax == -999:
vmax = np.max(layer[:, feature_index])
adata.obs[feature_annotation] = layer[:, feature_index]
vmax = np.max(feature_values)

feature_annotation = feature + "_spatial_plot"
adata.obs[feature_annotation] = feature_values.flatten()
color_region = feature_annotation
else:
color_region = annotation
vmin = None
vmax = None

if ax is None:
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)

ax = sc.pl.spatial(
adata=adata,
layer=layer,
color=color_region,
spot_size=spot_size,

spatial_coords = adata.obsm['spatial']
x_coords, y_coords = spatial_coords[:, 0], spatial_coords[:, 1]

# Color handling based on the annotation
if pin_color_rules:
color_map = adata.uns.get(pin_color_rules, {})
elif '_spac_colors' in adata.uns:
color_map = adata.uns['_spac_colors']
else:
color_map = {}

# Check if annotation is valid and exists in adata.obs columns
if annotation is not None and annotation not in adata.obs.columns:
error_text = f'The annotation "{annotation}"' + \
' not found in the dataset.' + \
f" Existing annotations are: {', '.join(adata.obs.columns.tolist())}"
raise ValueError(error_text)

# If annotation is provided, use it for color mapping, otherwise default to gray
if annotation is not None:
colors = [color_map.get(label, 'gray') for label in adata.obs[annotation]]
else:
colors = ['gray'] * len(x_coords) # Default color if no annotation is provided

# Create scatter plot
scatter = ax.scatter(
x=x_coords,
y=y_coords,
c=colors,
s=spot_size,
alpha=alpha,
vmin=vmin,
vmax=vmax,
ax=ax,
show=False,
**kwargs)
**kwargs
)

# Set color limits for feature-based visualization
if feature is not None:
scatter.set_clim(vmin, vmax)

return [ax]

return ax


def boxplot(adata, annotation=None, second_annotation=None, layer=None,
Expand Down
Loading