Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions src/stamp/encoding/encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,30 +219,20 @@ def _save_features_(


def _resolve_extractor_name(raw: str) -> ExtractorName:
"""
Resolve an extractor string to a valid ExtractorName.

Handles:
- exact matches ('gigapath', 'virchow-full')
- versioned strings like 'gigapath-ae23d', 'virchow-full-2025abc'
Raises ValueError if the base name is not recognized.
"""
if not raw:
raise ValueError("Empty extractor string")

name = str(raw).strip().lower()
name = name.replace("_", "-")

# Exact match
for e in ExtractorName:
if name == e.value.lower():
return e

# Versioned form: '<enum-value>-something'
for e in ExtractorName:
if name.startswith(e.value.lower() + "-"):
return e

# Otherwise fail
raise ValueError(
f"Unknown extractor '{raw}'. "
f"Expected one of {[e.value for e in ExtractorName]} "
Expand Down
62 changes: 56 additions & 6 deletions src/stamp/encoding/encoder/eagle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
from collections import defaultdict, deque
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -59,11 +60,26 @@ def _validate_and_read_features_with_agg(
f"Features located in {h5_vir2} are extracted with {extractor}"
)

if feats.shape[0] != agg_feats.shape[0]:
raise ValueError(
f"Number of ctranspath features and virchow2 features do not match:"
f" {feats.shape[0]} != {agg_feats.shape[0]}"
)
# if feats.shape[0] != agg_feats.shape[0]:
# raise ValueError(
# f"Number of ctranspath features and virchow2 features do not match:"
# f" {feats.shape[0]} != {agg_feats.shape[0]}"
# )
if not np.allclose(coords.coords_um, agg_coords.coords_um, atol=1e-5, rtol=0):
# Try to fix permutation by aligning virchow2 to ctp coords
try:
agg_feats, aligned_agg_coords = _align_vir2_to_ctp_by_coords(
ref_coords_um=coords.coords_um,
other_coords_um=agg_coords.coords_um,
other_feats=agg_feats,
decimals=5,
)
agg_coords.coords_um = aligned_agg_coords # optional, for debugging
except ValueError as e:
raise ValueError(
f"Coordinates mismatch between ctranspath and virchow2 features for slide "
f"{slide_name}. Alignment attempt failed: {e}"
)

if not np.allclose(coords.coords_um, agg_coords.coords_um, atol=1e-5, rtol=0):
raise ValueError(
Expand Down Expand Up @@ -144,7 +160,7 @@ def encode_slides_(
for tile_feats_filename in (progress := tqdm(os.listdir(feat_dir))):
h5_ctp = os.path.join(feat_dir, tile_feats_filename)
h5_vir2 = os.path.join(agg_feat_dir, tile_feats_filename)
slide_name: str = Path(tile_feats_filename).stem
slide_name: str = Path(tile_feats_filename).name
progress.set_description(slide_name)

# skip patient in case feature file already exists
Expand Down Expand Up @@ -238,3 +254,37 @@ def encode_patients_(
self._save_features_(
output_path=output_path, feats=patient_embedding, feat_type="patient"
)


def _align_vir2_to_ctp_by_coords(
    ref_coords_um: np.ndarray,
    other_coords_um: np.ndarray,
    other_feats: torch.Tensor,
    decimals: int = 5,
) -> tuple[torch.Tensor, np.ndarray]:
    """Reorder ``other_feats`` so its rows follow the order of ``ref_coords_um``.

    Coordinates are rounded to ``decimals`` places and matched exactly.
    Duplicate coordinates are paired in order of appearance: the first
    occurrence in the reference is matched with the first occurrence in the
    other set, and so on.

    Args:
        ref_coords_um: Reference coordinates, one row per tile.
        other_coords_um: Coordinates belonging to ``other_feats``; expected to
            be a permutation of ``ref_coords_um`` after rounding.
        other_feats: Features whose first dimension indexes the rows of
            ``other_coords_um``.
        decimals: Number of decimal places used when comparing coordinates.

    Returns:
        A tuple ``(aligned_feats, aligned_coords)`` where row ``i`` of each
        corresponds to row ``i`` of ``ref_coords_um``. ``aligned_coords``
        holds the original (unrounded) values of ``other_coords_um``.

    Raises:
        ValueError: If the number of feature rows differs from the number of
            coordinate rows, if a reference coordinate has no remaining match
            in the other set, or if the other set contains coordinates that
            are not consumed by the reference.
    """
    other_coords = np.asarray(other_coords_um)
    ref = np.round(np.asarray(ref_coords_um, dtype=np.float64), decimals)
    oth = np.round(np.asarray(other_coords, dtype=np.float64), decimals)

    # Features and coordinates must describe the same tiles; otherwise the
    # permutation computed from the coordinates is meaningless for the features.
    if other_feats.shape[0] != oth.shape[0]:
        raise ValueError(
            f"Number of features ({other_feats.shape[0]}) does not match "
            f"number of coordinates ({oth.shape[0]})."
        )

    # coord -> queue(indices); a queue so duplicate coords are matched in order
    buckets = defaultdict(deque)
    for j, key in enumerate(map(tuple, oth)):
        buckets[key].append(j)

    perm = np.empty(ref.shape[0], dtype=np.int64)
    for i, key in enumerate(map(tuple, ref)):
        # .get() avoids inserting empty queues for coords that are missing
        queue = buckets.get(key)
        if not queue:
            raise ValueError(f"Missing coord in other set: {key}")
        perm[i] = queue.popleft()

    # Anything left over is a coordinate that has no partner in the reference.
    unused = sum(len(q) for q in buckets.values())
    if unused != 0:
        raise ValueError(f"virchow2 features contain {unused} extra coords not in ref.")

    perm_t = torch.as_tensor(perm, dtype=torch.long, device=other_feats.device)
    # Apply the same permutation to features and coordinates so they stay paired.
    aligned_feats = other_feats.index_select(0, perm_t)
    aligned_coords = other_coords[perm]
    return aligned_feats, aligned_coords