From f24330f037001d26906deb0479beb0b3e4b43775 Mon Sep 17 00:00:00 2001
From: LouisK92
Date: Thu, 7 Aug 2025 22:00:39 +0200
Subject: [PATCH] Add draft for segger

---
 .../segger/config.vsh.yaml | 101 ++++++
 .../segger/script.py       | 338 ++++++++++++++++++
 2 files changed, 439 insertions(+)
 create mode 100644 src/methods_transcript_assignment/segger/config.vsh.yaml
 create mode 100644 src/methods_transcript_assignment/segger/script.py

diff --git a/src/methods_transcript_assignment/segger/config.vsh.yaml b/src/methods_transcript_assignment/segger/config.vsh.yaml
new file mode 100644
index 000000000..901282c78
--- /dev/null
+++ b/src/methods_transcript_assignment/segger/config.vsh.yaml
@@ -0,0 +1,101 @@
+__merge__: /src/api/comp_method_transcript_assignment.yaml
+
+name: segger
+label: "Segger Transcript Assignment"
+summary: "Assign transcripts to cells using the Segger method"
+description: "Segger is a tool for cell segmentation in single-molecule spatial omics datasets that leverages graph neural networks (GNNs) and heterogeneous graphs."
+links:
+  documentation: "https://elihei2.github.io/segger_dev/"
+  repository: "https://github.com/EliHei2/segger_dev"
+references:
+  doi: "10.1101/2025.03.14.643160"
+
+arguments:
+  - name: --transcripts_key
+    type: string
+    description: The key of the transcripts within the points of the spatial data
+    default: transcripts
+  - name: --coordinate_system
+    type: string
+    description: The key of the pixel space coordinate system within the spatial data
+    default: global
+
+# - name: --force_2d
+#   type: string
+#   required: false
+#   description: "Ignores the z-column in the data if it is provided"
+#   direction: input
+#   default: "false"
+#
+# - name: --min_molecules_per_cell
+#   type: integer
+#   required: false
+#   description: "Minimum number of molecules per cell"
+#   direction: input
+#   default: 50
+#
+# - name: --scale
+#   type: double
+#   required: false
+#   description: |
+#     "Scale parameter, which suggests an approximate cell radius for the algorithm. Must be in the same units as
+#     x and y molecule coordinates. Negative values mean it must be estimated from `min_molecules_per_cell`."
+#   direction: input
+#   default: -1.0
+#
+# - name: --scale_std
+#   type: string
+#   required: false
+#   description: "Standard deviation of scale across cells relative to `scale`"
+#   direction: input
+#   default: "25%"
+#
+# - name: --n_clusters
+#   type: integer
+#   required: false
+#   description: "Number of molecule clusters, i.e. major cell types."
+#   direction: input
+#   default: 4
+#
+# - name: --prior_segmentation_confidence
+#   type: double
+#   required: false
+#   description: "Confidence of the prior segmentation"
+#   direction: input
+#   default: 0.8
+
+resources:
+  - type: python_script
+    path: script.py
+
+# Alternative engine that builds segger from source (kept for reference):
+# engines:
+#   - type: docker
+#     image: openproblems/base_python:1
+#     __merge__:
+#       - /src/base/setup_spatialdata_partial.yaml
+#     setup:
+#       - type: python
+#         pypi: [torch, pytorch-lightning, lightning]
+#       - type: docker
+#         run:
+#           - git clone https://github.com/EliHei2/segger_dev.git
+#           - cd segger_dev && pip install .[cuda12]
+engines:
+  - type: docker
+    image: danielunyi42/segger_dev:cuda121
+    setup:
+      - type: apt
+        packages: procps
+      - type: python
+        github:
+          - openproblems-bio/core#subdirectory=packages/python/openproblems
+      - type: python
+        packages:
+          - spatialdata
+  - type: native
+
+runners:
+  - type: executable
+  - type: nextflow
+    directives:
+      label: [ midtime, midcpu, midmem, gpu ]
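+
+# Example local invocation for testing (a sketch, assuming standard `viash run`
+# syntax; the paths are the test resources used in script.py):
+#   viash run src/methods_transcript_assignment/segger/config.vsh.yaml -- \
+#     --input_ist resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr \
+#     --input_segmentation resources_test/task_ist_preprocessing/mouse_brain_combined/segmentation.zarr \
+#     --output segger_assigned_transcripts.zarr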
diff --git a/src/methods_transcript_assignment/segger/script.py b/src/methods_transcript_assignment/segger/script.py
new file mode 100644
index 000000000..84cf7e976
--- /dev/null
+++ b/src/methods_transcript_assignment/segger/script.py
@@ -0,0 +1,338 @@
+from pathlib import Path
+import dask.array
+import dask.dataframe
+import torch
+import xarray as xr
+import numpy as np
+import geopandas as gpd
+import spatialdata as sd
+from segger.data.parquet.sample import STSampleParquet
+from segger.training.segger_data_module import SeggerDataModule
+from segger.training.train import LitSegger, Segger
+from torch_geometric.nn import to_hetero
+from lightning.pytorch import Trainer
+#from pytorch_lightning import Trainer
+from lightning.pytorch.loggers import CSVLogger
+from segger.prediction.predict_parquet import segment, load_model
+
+
+## VIASH START
+# Note: this section is auto-generated by viash at runtime. To edit it, make changes
+# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
+par = {
+    'input_ist': 'resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr',
+    'input_segmentation': 'resources_test/task_ist_preprocessing/mouse_brain_combined/segmentation.zarr',
+    #'input_segmentation': 'temp/methods/cellpose/segmentation.zarr',
+    'transcripts_key': 'transcripts',
+    'coordinate_system': 'global',
+    'output': './temp/methods/segger/segger_assigned_transcripts.zarr',
+
+    #TODO: Add the interesting tuning parameters, see
+    # https://elihei2.github.io/segger_dev/notebooks/segger_tutorial/#4-tune-parameters
+    #'force_2d': 'false',
+    #'min_molecules_per_cell': 50,
+    #'scale': -1.0,
+    #'scale_std': "25%",
+    #'n_clusters': 4,
+    #'prior_segmentation_confidence': 0.8,
+}
+meta = {
+    'name': 'segger_transcript_assignment',
+    'temp_dir': "./temp/methods/segger",
+    'cpus': 4,
+}
+## VIASH END
+
+TMP_DIR = Path(meta["temp_dir"] or "/tmp")
+TMP_DIR.mkdir(parents=True, exist_ok=True)
+
+
+#NOTE: For the data file preparation we follow
+# https://github.com/EliHei2/segger_dev/blob/generic_config/platform_guides/platform_preparation_guide.ipynb and
+# https://github.com/EliHei2/segger_dev/blob/main/src/segger/data/parquet/_settings/xenium.yaml
+
+POLYGON_PARQUET = TMP_DIR / "nucleus_boundaries.parquet"
+TRANSCRIPTS_PARQUET = TMP_DIR / "transcripts.parquet"
+SEGGER_DATA_DIR = TMP_DIR / 'data_segger'
+SEGGER_DATA_DIR.mkdir(parents=True, exist_ok=True)
+
+MODELS_DIR = TMP_DIR / 'models'
+MODELS_DIR.mkdir(parents=True, exist_ok=True)
+
+
+print('Checking if CUDA is available:', flush=True)
+print("\t torch.cuda.is_available() = ", torch.cuda.is_available())
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load data
+sdata = sd.read_zarr(par['input_ist'])
+sdata_segm = sd.read_zarr(par['input_segmentation'])
+
+######################
+# boundaries.parquet #
+######################
+
+# Compute boundaries from segmentation
+print('Computing boundaries from segmentation', flush=True)
+boundaries = sd.to_polygons(sdata_segm['segmentation'])[['geometry']]
+
+# Bring the boundaries into the coordinate system of the transcripts
+# (composed as segmentation -> global -> transcripts) and save them as parquet
+print('Transforming boundaries to transcripts coordinate system', flush=True)
+trans_segm_to_global = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True)[par['coordinate_system']]
+trans_global_to_transcripts = sd.transformations.get_transformation(sdata[par['transcripts_key']], get_all=True)[par['coordinate_system']].inverse()
+trans = sd.transformations.Sequence([trans_segm_to_global, trans_global_to_transcripts])
+boundaries = sd.transform(boundaries, trans, par['coordinate_system'])
+boundaries.index.name = "cell_id"
+
+# Convert the polygons into the Xenium parquet format (one row per vertex)
+boundaries_df = boundaries['geometry'].get_coordinates().rename(columns={'x': 'vertex_x', 'y': 'vertex_y'})
+boundaries_df = boundaries_df.reset_index()
+boundaries_df['cell_id'] = boundaries_df['cell_id'].astype(str) + "_id"
+boundaries_df['cell_id'] = boundaries_df['cell_id'].astype('object')
+boundaries_df['vertex_x'] = boundaries_df['vertex_x'].astype(np.float32)
+boundaries_df['vertex_y'] = boundaries_df['vertex_y'].astype(np.float32)
+boundaries_df.to_parquet(POLYGON_PARQUET)
+del boundaries
+del boundaries_df
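+
+# Optional sanity check (a sketch): re-read the parquet and confirm the
+# Xenium-style schema that STSampleParquet expects, per the xenium.yaml
+# settings file referenced above.
+# import pandas as pd
+# bnd = pd.read_parquet(POLYGON_PARQUET)
+# assert {'cell_id', 'vertex_x', 'vertex_y'}.issubset(bnd.columns)
+# assert bnd['vertex_x'].dtype == np.float32 and bnd['vertex_y'].dtype == np.float32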
+
+#######################
+# transcripts.parquet #
+#######################
+
+# Map segmentation ids to transcripts (basic assignment)
+transcripts_coord_systems = sd.transformations.get_transformation(sdata[par["transcripts_key"]], get_all=True).keys()
+assert par['coordinate_system'] in transcripts_coord_systems, f"Coordinate system '{par['coordinate_system']}' not found in the transcripts."
+segmentation_coord_systems = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True).keys()
+assert par['coordinate_system'] in segmentation_coord_systems, f"Coordinate system '{par['coordinate_system']}' not found in the segmentation."
+
+print('Transforming transcripts coordinates', flush=True)
+transcripts = sd.transform(sdata[par['transcripts_key']], to_coordinate_system=par['coordinate_system'])
+
+# In case of a translation transformation of the segmentation (e.g. a crop of the data),
+# we need to adjust the transcript coordinates
+trans = sd.transformations.get_transformation(sdata_segm["segmentation"], get_all=True)[par['coordinate_system']].inverse()
+transcripts = sd.transform(transcripts, trans, par['coordinate_system'])
+
+print('Assigning transcripts to cell ids', flush=True)
+y_coords = transcripts.y.compute().to_numpy(dtype=np.int64)
+x_coords = transcripts.x.compute().to_numpy(dtype=np.int64)
+if isinstance(sdata_segm["segmentation"], xr.DataTree):
+    label_image = sdata_segm["segmentation"]["scale0"].image.to_numpy()
+else:
+    label_image = sdata_segm["segmentation"].to_numpy()
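+
+# Note: the integer pixel lookup below assumes every transcript falls within the
+# bounds of the label image. If that cannot be guaranteed (e.g. transcripts at
+# the edge of a crop), a defensive variant (a sketch, not part of the current
+# draft) would clip the indices first:
+# h, w = label_image.shape[-2:]
+# y_coords = np.clip(y_coords, 0, h - 1)
+# x_coords = np.clip(x_coords, 0, w - 1)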
+cell_id_dask_series = dask.dataframe.from_dask_array(
+    dask.array.from_array(
+        label_image[y_coords, x_coords],
+        chunks=tuple(sdata[par['transcripts_key']].map_partitions(len).compute())
+    ),
+    index=sdata[par['transcripts_key']].index
+)
+transcripts_original = sdata[par['transcripts_key']]
+transcripts_original['cell_id'] = cell_id_dask_series
+transcripts_original["overlaps_nucleus"] = (transcripts_original['cell_id'] != 0)
+transcripts_original["transcript_id"] = transcripts_original.index
+#TODO: The qv column of Xenium should not be needed in the end.
+# See the issue at https://github.com/EliHei2/segger_dev/issues/125
+#transcripts_original["qv"] = 1.0
+
+# Generate the same dtypes as in Xenium parquet files; unassigned transcripts
+# (label 0) get the Xenium convention 'UNASSIGNED'.
+# NOTE: a plain astype('object') could probably work as well, but to be sure we convert via strings.
+transcripts_original['cell_id'] = transcripts_original['cell_id'].astype(str) + "_id"
+transcripts_original['cell_id'] = transcripts_original['cell_id'].replace('0_id', 'UNASSIGNED')
+transcripts_original['overlaps_nucleus'] = transcripts_original['overlaps_nucleus'].astype(np.uint8)
+#transcripts_original['qv'] = transcripts_original['qv'].astype(np.float32)
+transcripts_original['transcript_id'] = transcripts_original['transcript_id'].astype(np.uint64)
+
+# Rename columns to match Xenium parquet files
+rename_cols = {'x': 'x_location', 'y': 'y_location', 'z': 'z_location', 'feature_name': 'feature_name'}
+transcripts_original = transcripts_original.rename(columns=rename_cols)
+cols = list(rename_cols.values()) + ['cell_id', 'transcript_id', 'overlaps_nucleus']  # 'qv',
+
+# Convert to a pandas dataframe and cast the remaining dtypes to match Xenium parquet files
+df_transcripts = transcripts_original[cols].compute()
+df_transcripts['cell_id'] = df_transcripts['cell_id'].astype('object')  # converting to dtype 'object' only worked after compute()
+df_transcripts['feature_name'] = df_transcripts['feature_name'].astype('object')
+
+# Save transcripts to parquet
+df_transcripts.to_parquet(TRANSCRIPTS_PARQUET)
+
+del df_transcripts
+del transcripts_original
+del transcripts
+del y_coords
+del x_coords
+del label_image
+del cell_id_dask_series
+
+
+######################
+# Segger data loader #
+######################
+
+print('Preparing Segger data', flush=True)
+sample = STSampleParquet(base_dir=TMP_DIR, n_workers=4, sample_type='xenium')
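+
+# Parameter notes for sample.save() below (our reading of the Segger tutorial):
+# k_bd/dist_bd control the k-NN graph between transcripts and boundaries,
+# k_tx/dist_tx the transcript-transcript graph, tile_width/tile_height the
+# spatial tiling of the sample, neg_sampling_ratio the negative edge sampling
+# during training, and val_prob/test_prob the validation/test split of tiles.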
+sample.save(
+    data_dir=SEGGER_DATA_DIR,
+    k_bd=3,
+    dist_bd=15.0,
+    k_tx=3,
+    dist_tx=5.0,
+    tile_width=120,
+    tile_height=120,
+    neg_sampling_ratio=5.0,
+    frac=1.0,
+    val_prob=0.1,
+    test_prob=0.2,
+)
+
+#TODO: Optionally add a scRNA-seq based gene embedding, see
+# https://elihei2.github.io/segger_dev/notebooks/segger_tutorial/#12-using-custom-gene-embeddings
+# Note that the arguments is_token_based and num_tx_tokens need to change for such an embedding (see below).
+
+
+######################
+# Train Segger model #
+######################
+
+# Initialize the Lightning data module
+print('Initializing Segger data module', flush=True)
+dm = SeggerDataModule(
+    data_dir=SEGGER_DATA_DIR,
+    batch_size=2,
+    num_workers=meta["cpus"] or 1,
+)
+dm.setup()
+
+#is_token_based = True
+num_tx_tokens = 500
+
+# If you use custom (scRNA-seq based) gene embeddings, use the following two lines instead:
+# is_token_based = False
+# num_tx_tokens = dm.train[0].x_dict["tx"].shape[1]  # Set the number of tokens to the number of genes
+
+# Initialize the model
+#NOTE: Initializing LitSegger directly as below throws an error and deviates from the
+# current documentation; instead we follow
+# https://github.com/EliHei2/segger_dev/blob/main/docs/notebooks/segger_tutorial.ipynb
+print('Initializing Segger model', flush=True)
+###ls = LitSegger(
+###    #is_token_based = is_token_based,
+###    #num_node_features = {"tx": num_tx_tokens, "bd": num_bd_features},
+###    #init_emb=8,
+###    #hidden_channels=64,
+###    out_channels=16,
+###    heads=4,
+###    num_mid_layers=1,
+###    aggr='sum',
+###)
+
+model = Segger(
+    # is_token_based=is_token_based,
+    num_tx_tokens=num_tx_tokens,
+    init_emb=8,
+    hidden_channels=64,
+    out_channels=16,
+    heads=4,
+    num_mid_layers=3,
+)
+model = to_hetero(model, (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")]), aggr="sum")
+
+# One dummy forward pass (initializes the lazily built heterogeneous modules)
+# before wrapping the model
+batch = dm.train[0]
+model.forward(batch.x_dict, batch.edge_index_dict)
+
+# Wrap the model in LitSegger
+ls = LitSegger(model=model)
+
+
+# Initialize the Lightning trainer
+print('Training Segger model', flush=True)
+trainer = Trainer(
+    accelerator=DEVICE,
+    strategy='auto',
+    precision='16-mixed',
+    devices=1,  # set to a higher number if more GPUs are available
+    max_epochs=2,  # 100
+    default_root_dir=MODELS_DIR,
+    logger=CSVLogger(MODELS_DIR),
+)
+
+# Fit model
+trainer.fit(
+    model=ls,
+    datamodule=dm
+)
+
+# For debugging:
+"""
+# Evaluate results
+import pandas as pd
+import matplotlib.pyplot as plt
+
+model_version = 0  # 'v_num' from training output above
+model_path = MODELS_DIR / 'lightning_logs' / f'version_{model_version}'
+metrics = pd.read_csv(model_path / 'metrics.csv', index_col=1)
+
+fig, ax = plt.subplots(1, 1, figsize=(2, 2))
+
+for col in metrics.columns.difference(['epoch']):
+    metric = metrics[col].dropna()
+    ax.plot(metric.index, metric.values, label=col)
+
+ax.legend(loc=(1, 0.33))
+ax.set_ylim(0, 1)
+ax.set_xlabel('Step')
+"""
+
+#############
+# Inference #
+#############
+
+print('Loading Segger model', flush=True)
+model_version = 0
+model_path = MODELS_DIR / "lightning_logs" / f"version_{model_version}"
+model = load_model(model_path / "checkpoints")
+
+#NOTE: Parameters selected according to https://github.com/EliHei2/segger_dev/issues/126#issuecomment-3160389883
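+
+#NOTE on the values below (our reading of the linked issue comment and the Segger
+# docs): receptive_field mirrors the k-NN graph parameters used during data
+# preparation (boundary and transcript neighbors), and score_cut is the threshold
+# on the predicted transcript-to-cell association score below which transcripts
+# stay unassigned.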
+receptive_field = {'k_bd': 4, 'dist_bd': 7.5, 'k_tx': 15, 'dist_tx': 3}
+
+print('Running Segger inference', flush=True)
+#NOTE: Running into this error (also with the latest segger version, 2025-08-07):
+# https://github.com/EliHei2/segger_dev/issues/126
+segment(
+    model,
+    dm,
+    score_cut=0.75,
+    save_dir=TMP_DIR,
+    seg_tag='segger_output',
+    transcript_file=TRANSCRIPTS_PARQUET,
+    receptive_field=receptive_field,
+    min_transcripts=5,
+    cell_id_col='segger_cell_id',
+    use_cc=False,
+    knn_method='kd_tree',
+    verbose=True,
+)
+
+##########################################################
+# Load Segger results and save transcripts with cell ids #
+##########################################################
+
+#TODO: Look at other transcript assignment methods for these last steps
+
+#TODO: Check whether these lines are correct
+#
+#sdata = sd.read_zarr(par['input_ist'])
+#segger_df = gpd.read_parquet(TMP_DIR / "segger_output.parquet")
+#
+#transcripts = sdata[par['transcripts_key']]
+#assert (transcripts.index.compute() == segger_df.index).all(), "Transcripts and Segger results have different indices"
+#transcripts['cell_id'] = segger_df['segger_cell_id']
+
+#TODO: Save transcripts with cell ids
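+
+# A minimal sketch of the two TODOs above. Assumptions (not verified against the
+# segger output yet): segment() writes '<seg_tag>.parquet' into save_dir with a
+# 'segger_cell_id' column aligned with the input transcript index, and the
+# expected component output is a SpatialData zarr store whose transcripts carry
+# a 'cell_id' column:
+#
+# import pandas as pd
+# sdata = sd.read_zarr(par['input_ist'])
+# segger_df = pd.read_parquet(TMP_DIR / "segger_output.parquet")
+# transcripts = sdata[par['transcripts_key']].compute()
+# transcripts['cell_id'] = segger_df.loc[transcripts.index, 'segger_cell_id']
+# sdata_out = sd.SpatialData(points={par['transcripts_key']: sd.models.PointsModel.parse(transcripts)})
+# Path(par['output']).parent.mkdir(parents=True, exist_ok=True)
+# sdata_out.write(par['output'])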