Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 246 additions & 6 deletions src/spac/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from matplotlib.colors import ListedColormap, BoundaryNorm
from spac.utils import check_table, check_annotation
from spac.utils import check_feature, annotation_category_relations
from spac.utils import check_label
from spac.utils import check_label, check_list_in_list
from spac.utils import get_defined_color_map
from spac.utils import compute_boxplot_metrics
from functools import partial
Expand All @@ -35,7 +35,7 @@ def visualize_2D_scatter(
x, y, labels=None, point_size=None, theme=None,
ax=None, annotate_centers=False,
x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None,
color_representation=None, **kwargs
color_representation=None, color_map=None, **kwargs
):
"""
Visualize 2D data using plt.scatter.
Expand Down Expand Up @@ -65,6 +65,8 @@ def visualize_2D_scatter(
Title for the plot.
color_representation : str, optional
Description of what the colors represent.
color_map : dictionary, optional
Dictionary containing colors for label annotations.
**kwargs
Additional keyword arguments passed to plt.scatter.

Expand All @@ -83,6 +85,8 @@ def visualize_2D_scatter(
raise ValueError("x and y must have the same length.")
if labels is not None and len(labels) != len(x):
raise ValueError("Labels length should match x and y length.")
if color_map is not None and not isinstance(color_map, dict):
raise ValueError("`color_map` must be a dict mapping label→color.")

# Define color themes
themes = {
Expand Down Expand Up @@ -141,15 +145,17 @@ def visualize_2D_scatter(
cmap2 = plt.get_cmap('tab20b')
cmap3 = plt.get_cmap('tab20c')
colors = cmap1.colors + cmap2.colors + cmap3.colors

# Use the number of unique clusters to set the colormap length
cmap = ListedColormap(colors[:len(unique_clusters)])
cluster_to_color = color_map if color_map is not None else {
str(cluster): colors[i % len(colors)]
for i, cluster in enumerate(unique_clusters)
}

for idx, cluster in enumerate(unique_clusters):
mask = np.array(labels) == cluster
color = cluster_to_color.get(str(cluster), 'gray')
ax.scatter(
x[mask], y[mask],
color=cmap(idx),
color=color,
label=cluster,
s=point_size
)
Expand Down Expand Up @@ -196,6 +202,240 @@ def visualize_2D_scatter(
return fig, ax


def embedded_scatter_plot(
adata,
method=None,
annotation=None,
feature=None,
layer=None,
ax=None,
associated_table=None,
spot_size=20,
alpha=0.5,
vmin=-999,
vmax=-999,
color_map=None,
**kwargs):
"""
Visualize scatter plot in PCA, t-SNE, UMAP, spatial or associated table.

Parameters
----------
adata : anndata.AnnData
The AnnData object with coordinates precomputed by the 'tsne' or 'UMAP'
function and stored in 'adata.obsm["X_tsne"]' or 'adata.obsm["X_umap"]'
method : str, optional (default: None)
Visualization method specifying the coordinate system to plot.
Choose from {'tsne', 'umap', 'pca', 'spatial'}.
annotation : str, optional
The name of the column in `adata.obs` to use for coloring
the scatter plot points based on cell annotations.
feature : str, optional
The name of the gene or feature in `adata.var_names` to use
for coloring the scatter plot points based on feature expression.
layer : str, optional
The name of the data layer in `adata.layers` to use for visualization.
If None, the main data matrix `adata.X` is used.
ax : matplotlib.axes.Axes, optional (default: None)
A matplotlib axes object to plot on.
If not provided, a new figure and axes will be created.
associated_table : str, optional (default: None)
Name of the key in `obsm` that contains the numpy array. Takes
precedence over `method`
color_map : str, optional (default: None)
Name of the key in adata.uns that contains color-mapping for
the plot
**kwargs
Parameters passed to visualize_2D_scatter function,
including point_size.

Returns
-------
fig : matplotlib.figure.Figure
The created figure for the plot.
ax : matplotlib.axes.Axes
The axes of the plot.
"""

# Check if both annotation and feature are specified, raise error if so
if annotation and feature:
raise ValueError(
"Please specify either an annotation or a feature for coloring, "
"not both.")

# Use utility functions for input validation
if layer:
check_table(adata, tables=layer)
if annotation:
check_annotation(adata, annotations=annotation)
if feature:
check_feature(adata, features=[feature])
color_mapping = None
if color_map is not None:
color_mapping = get_defined_color_map(
adata,
defined_color_map=color_map,
annotations=annotation
)

# Validate the method and check if the necessary data exists in adata.obsm
if associated_table is None:
valid_methods = ['tsne', 'umap', 'pca', 'spatial']
check_list_in_list(input=method, input_name="method",
input_type="method",
target_list=valid_methods,
need_exist=True
)
if method == "spatial":
key = "spatial"
else:
key = f'X_{method}'
if key not in adata.obsm.keys():
raise ValueError(
f"{key} coordinates not found in adata.obsm. "
f"Please run {method.upper()} before calling this function."
)

else:
check_table(
adata=adata,
tables=associated_table,
should_exist=True,
associated_table=True
)

associated_table_shape = adata.obsm[associated_table].shape
if associated_table_shape[1] != 2:
raise ValueError(
f'The associated table:"{associated_table}" does not have'
f' two dimensions. It shape is:"{associated_table_shape}"'
)
key = associated_table

err_msg_feat_annotation_non = "Both annotation and feature are None, " + \
"please provide single input."
err_msg_spot_size = "The 'spot_size' parameter must be an integer, " + \
f"got {str(type(spot_size))}"
err_msg_alpha_type = "The 'alpha' parameter must be a float," + \
f"got {str(type(alpha))}"
err_msg_alpha_value = "The 'alpha' parameter must be between " + \
f"0 and 1 (inclusive), got {str(alpha)}"
err_msg_vmin = "The 'vmin' parameter must be a float or an int, " + \
f"got {str(type(vmin))}"
err_msg_vmax = "The 'vmax' parameter must be a float or an int, " + \
f"got {str(type(vmax))}"
err_msg_ax = "The 'ax' parameter must be an instance " + \
f"of matplotlib.axes.Axes, got {str(type(ax))}"

if adata is None:
raise ValueError("The input dataset must not be None.")

if not isinstance(adata, anndata.AnnData):
err_msg_adata = "The 'adata' parameter must be an " + \
f"instance of anndata.AnnData, got {str(type(adata))}."
raise ValueError(err_msg_adata)

if key == "spatial":
if annotation is None and feature is None:
raise ValueError(err_msg_feat_annotation_non)

if 'spatial' not in adata.obsm_keys():
err_msg = "Spatial coordinates not found in the 'obsm' attribute."
raise ValueError(err_msg)

# Extract feature name
if not isinstance(spot_size, int):
raise ValueError(err_msg_spot_size)

if not isinstance(alpha, float):
raise ValueError(err_msg_alpha_type)

if not (0 <= alpha <= 1):
raise ValueError(err_msg_alpha_value)

if vmin != -999 and not (
isinstance(vmin, float) or isinstance(vmin, int)
):
raise ValueError(err_msg_vmin)

if vmax != -999 and not (
isinstance(vmax, float) or isinstance(vmax, int)
):
raise ValueError(err_msg_vmax)

if ax is not None and not isinstance(ax, plt.Axes):
raise ValueError(err_msg_ax)

print(f'Running visualization using the coordinates: "{key}"')

# Extract the 2D coordinates
x, y = adata.obsm[key].T

# Determine coloring scheme
if color_mapping is None:
if annotation:
color_values = adata.obs[annotation].astype('category').values
color_representation = annotation
elif feature:
data_src = adata.layers[layer] if layer else adata.X
color_values = data_src[:, adata.var_names == feature].squeeze()
color_representation = feature
else:
color_values = None
color_representation = None
else:
color_values = adata.obs[annotation].astype('category').values
color_representation = annotation

# Set axis titles based on method and color representation
if method == 'tsne':
x_axis_title = 't-SNE 1'
y_axis_title = 't-SNE 2'
plot_title = f'TSNE-{color_representation}'
elif method == 'pca':
x_axis_title = 'PCA 1'
y_axis_title = 'PCA 2'
plot_title = f'PCA-{color_representation}'
elif method == 'umap':
x_axis_title = 'UMAP 1'
y_axis_title = 'UMAP 2'
plot_title = f'UMAP-{color_representation}'
elif method == 'spatial':
x_axis_title = 'SPATIAL 1'
y_axis_title = 'SPATIAL 2'
plot_title = f'SPATIAL-{color_representation}'

else:
x_axis_title = f'{associated_table} 1'
y_axis_title = f'{associated_table} 2'
plot_title = f'{associated_table}-{color_representation}'

# Remove conflicting keys from kwargs
kwargs.pop('x_axis_title', None)
kwargs.pop('y_axis_title', None)
kwargs.pop('plot_title', None)
kwargs.pop('color_representation', None)

# Set Min and Max in kwargs
kwargs['vmin'] = vmin
kwargs['vmax'] = vmax

fig, ax = visualize_2D_scatter(
x=x,
y=y,
ax=ax,
labels=color_values,
x_axis_title=x_axis_title,
y_axis_title=y_axis_title,
plot_title=plot_title,
color_representation=color_representation,
color_map=color_mapping,
**kwargs
)

return fig, ax


def dimensionality_reduction_plot(
adata,
method=None,
Expand Down
Loading