diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index ca12421..9e76af5 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -219,30 +219,20 @@ def _save_features_( def _resolve_extractor_name(raw: str) -> ExtractorName: - """ - Resolve an extractor string to a valid ExtractorName. - - Handles: - - exact matches ('gigapath', 'virchow-full') - - versioned strings like 'gigapath-ae23d', 'virchow-full-2025abc' - Raises ValueError if the base name is not recognized. - """ if not raw: raise ValueError("Empty extractor string") name = str(raw).strip().lower() + name = name.replace("_", "-") - # Exact match for e in ExtractorName: if name == e.value.lower(): return e - # Versioned form: '-something' for e in ExtractorName: if name.startswith(e.value.lower() + "-"): return e - # Otherwise fail raise ValueError( f"Unknown extractor '{raw}'. " f"Expected one of {[e.value for e in ExtractorName]} " diff --git a/src/stamp/encoding/encoder/eagle.py b/src/stamp/encoding/encoder/eagle.py index b2fb293..d966c84 100644 --- a/src/stamp/encoding/encoder/eagle.py +++ b/src/stamp/encoding/encoder/eagle.py @@ -1,5 +1,6 @@ import logging import os +from collections import defaultdict, deque from pathlib import Path import numpy as np @@ -59,11 +60,26 @@ def _validate_and_read_features_with_agg( f"Features located in {h5_vir2} are extracted with {extractor}" ) - if feats.shape[0] != agg_feats.shape[0]: - raise ValueError( - f"Number of ctranspath features and virchow2 features do not match:" - f" {feats.shape[0]} != {agg_feats.shape[0]}" - ) + # if feats.shape[0] != agg_feats.shape[0]: + # raise ValueError( + # f"Number of ctranspath features and virchow2 features do not match:" + # f" {feats.shape[0]} != {agg_feats.shape[0]}" + # ) + if not np.allclose(coords.coords_um, agg_coords.coords_um, atol=1e-5, rtol=0): + # Try to fix permutation by aligning virchow2 to ctp coords + try: + agg_feats, aligned_agg_coords = _align_vir2_to_ctp_by_coords( + ref_coords_um=coords.coords_um, + other_coords_um=agg_coords.coords_um, + other_feats=agg_feats, + decimals=5, + ) + agg_coords.coords_um = aligned_agg_coords # optional, for debugging + except ValueError as e: + raise ValueError( + f"Coordinates mismatch between ctranspath and virchow2 features for slide " + f"{slide_name}. Alignment attempt failed: {e}" + ) if not np.allclose(coords.coords_um, agg_coords.coords_um, atol=1e-5, rtol=0): raise ValueError( @@ -144,7 +160,7 @@ def encode_slides_( for tile_feats_filename in (progress := tqdm(os.listdir(feat_dir))): h5_ctp = os.path.join(feat_dir, tile_feats_filename) h5_vir2 = os.path.join(agg_feat_dir, tile_feats_filename) - slide_name: str = Path(tile_feats_filename).stem + slide_name: str = Path(tile_feats_filename).name progress.set_description(slide_name) # skip patient in case feature file already exists @@ -238,3 +254,37 @@ def encode_patients_( self._save_features_( output_path=output_path, feats=patient_embedding, feat_type="patient" ) + + +def _align_vir2_to_ctp_by_coords( + ref_coords_um: np.ndarray, + other_coords_um: np.ndarray, + other_feats: torch.Tensor, + decimals: int = 5, +) -> tuple[torch.Tensor, np.ndarray]: + """Align vir2 features to ctp features based on coordinates.""" + ref = np.round(np.asarray(ref_coords_um, dtype=np.float64), decimals) + oth = np.round(np.asarray(other_coords_um, dtype=np.float64), decimals) + + # coord -> queue(indices) + buckets = defaultdict(deque) + for j, key in enumerate(map(tuple, oth)): + buckets[key].append(j) + + perm = np.empty(ref.shape[0], dtype=np.int64) + for i, key in enumerate(map(tuple, ref)): + if not buckets[key]: + raise ValueError(f"Missing coord in other set: {key}") + perm[i] = buckets[key].popleft() + + # optional: check if other has extras not used + unused = sum(len(q) for q in buckets.values()) + if unused != 0: + raise ValueError(f"virchow2 features contain {unused} extra coords not in ref.") + + perm_t = torch.as_tensor(perm, dtype=torch.long, device=other_feats.device) + # Align features according to the permutation as well ! + aligned_feats = other_feats.index_select(0, perm_t) + aligned_coords = other_coords_um[perm] + print("") + return aligned_feats, aligned_coords