diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 4ab8416f..f5d8a62a 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -86,6 +86,7 @@ def _run_cli(args: argparse.Namespace) -> None: tile_size_um=config.preprocessing.tile_size_um, tile_size_px=config.preprocessing.tile_size_px, extractor=config.preprocessing.extractor, + tile_extractor=config.preprocessing.tile_extractor, max_workers=config.preprocessing.max_workers, device=config.preprocessing.device, default_slide_mpp=config.preprocessing.default_slide_mpp, diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 796140a5..769bb355 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -4,7 +4,7 @@ preprocessing: # Extractor to use for feature extractor. Possible options are "ctranspath", # "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom", # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", - # "virchow-full", "musk", "mstar", "plip" + # "virchow-full", "musk", "mstar", "plip", "ticon" # Some of them require requesting access to the respective authors beforehand. extractor: "chief-ctranspath" @@ -12,6 +12,9 @@ preprocessing: device: "cuda" # Optional settings: + # If "ticon" is selected, specify the tile-level model whose embeddings TICON contextualizes, + # e.g. "h-optimus-1", "virchow2", "conch1_5", "uni2", "gigapath" + tile_extractor: "h-optimus-1" # Having a cache dir will speed up extracting features multiple times, # e.g. with different feature extractors. Optional. @@ -249,7 +252,7 @@ heatmaps: slide_encoding: # Encoder to use for slide encoding. Possible options are "cobra", - # "eagle", "titan", "gigapath", "chief", "prism", "madeleine". + # "eagle", "titan", "gigapath", "chief", "prism", "madeleine", "ticon". encoder: "chief" # Directory to save the output files.
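Note on the new tile_extractor option: TICON only enhances a fixed set of tile-level foundation models, each with its own embedding dimension. A minimal sketch for inspecting that pairing, assuming this branch is installed (it uses TILE_EXTRACTOR_TO_TICON and get_ticon_key from the ticon_encoder module added further down in this diff):

    from stamp.encoding.encoder.ticon_encoder import TILE_EXTRACTOR_TO_TICON, get_ticon_key
    from stamp.preprocessing.config import ExtractorName

    # List the tile extractors TICON can contextualize and their embedding dims.
    for name, (key, dim) in TILE_EXTRACTOR_TO_TICON.items():
        print(f"{name}: ticon key={key}, dim={dim}")

    # The config above pairs extractor: "ticon" with tile_extractor: "h-optimus-1".
    key, dim = get_ticon_key(ExtractorName.H_OPTIMUS_1)
    assert dim == 1536

Selecting a tile_extractor outside this mapping raises a ValueError in get_ticon_key.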
diff --git a/src/stamp/encoding/__init__.py b/src/stamp/encoding/__init__.py index 9cb873bb..a2ad916a 100644 --- a/src/stamp/encoding/__init__.py +++ b/src/stamp/encoding/__init__.py @@ -69,6 +69,11 @@ def init_slide_encoder_( selected_encoder: Encoder = Prism() + case EncoderName.TICON: + from stamp.encoding.encoder.ticon_encoder import TiconEncoder + + selected_encoder: Encoder = TiconEncoder() + case Encoder(): selected_encoder = encoder @@ -155,6 +160,11 @@ def init_patient_encoder_( selected_encoder: Encoder = Prism() + case EncoderName.TICON: + from stamp.encoding.encoder.ticon_encoder import TiconEncoder + + selected_encoder: Encoder = TiconEncoder() + case Encoder(): selected_encoder = encoder diff --git a/src/stamp/encoding/config.py b/src/stamp/encoding/config.py index 1a2bcba7..c0db4477 100644 --- a/src/stamp/encoding/config.py +++ b/src/stamp/encoding/config.py @@ -14,6 +14,7 @@ class EncoderName(StrEnum): GIGAPATH = "gigapath" MADELEINE = "madeleine" PRISM = "prism" + TICON = "ticon" class SlideEncodingConfig(BaseModel, arbitrary_types_allowed=True): diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index ca124214..62d89793 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -79,11 +79,12 @@ def encode_slides_( tqdm.write(s=str(e)) continue - slide_embedding = self._generate_slide_embedding( - feats, device, coords=coords - ) + slide_embedding = self._generate_slide_embedding(feats, device, **kwargs) self._save_features_( - output_path=output_path, feats=slide_embedding, feat_type="slide" + output_path=output_path, + feats=slide_embedding, + feat_type="slide", + **kwargs, ) def encode_patients_( @@ -133,7 +134,7 @@ def encode_patients_( for _, row in group.iterrows(): slide_filename = row[filename_label] h5_path = os.path.join(feat_dir, slide_filename) - feats, _ = self._validate_and_read_features(h5_path) + feats, coords = self._validate_and_read_features(h5_path) feats_list.append(feats) if not feats_list: @@ -149,7 +150,10 @@ def encode_patients_( @abstractmethod def _generate_slide_embedding( - self, feats: torch.Tensor, device, **kwargs + self, + feats: torch.Tensor, + device, + **kwargs, ) -> np.ndarray: """Generate slide embedding. Must be implemented by subclasses.""" pass @@ -193,7 +197,7 @@ def _read_h5( return feats, coords, _resolve_extractor_name(extractor) def _save_features_( - self, output_path: Path, feats: np.ndarray, feat_type: str + self, output_path: Path, feats: np.ndarray, feat_type: str, **kwargs ) -> None: with ( NamedTemporaryFile(dir=output_path.parent, delete=False) as tmp_h5_file, @@ -201,6 +205,11 @@ def _save_features_( ): try: f["feats"] = feats + f["coords"] = kwargs.get("coords", np.array([])) + if "tile_size_um" in kwargs and kwargs["tile_size_um"] is not None: + f.attrs["tile_size_um"] = float(kwargs["tile_size_um"]) + if "tile_size_px" in kwargs and kwargs["tile_size_px"] is not None: + f.attrs["tile_size_px"] = int(kwargs["tile_size_px"]) f.attrs["version"] = stamp.__version__ f.attrs["encoder"] = str(self.identifier) f.attrs["precision"] = str(self.precision) diff --git a/src/stamp/encoding/encoder/ticon_encoder.py b/src/stamp/encoding/encoder/ticon_encoder.py new file mode 100644 index 00000000..8f700ac9 --- /dev/null +++ b/src/stamp/encoding/encoder/ticon_encoder.py @@ -0,0 +1,758 @@ +""" +TICON Model Architecture and Configuration. + +Shared between "Isolated" and "Contextualized" modes. 
+Contains all model components, configuration, and utility functions. +Adapted from: + +@misc{belagali2025ticonslideleveltilecontextualizer, + title={TICON: A Slide-Level Tile Contextualizer for Histopathology Representation Learning}, + author={Varun Belagali and Saarthak Kapse and Pierre Marza and Srijan Das and Zilinghan Li and Sofiène Boutaj and Pushpak Pati and Srikar Yellapragada and Tarak Nath Nandi and Ravi K Madduri and Joel Saltz and Prateek Prasanna and Stergios Christodoulidis and Maria Vakalopoulou and Dimitris Samaras}, + year={2025}, + eprint={2512.21331}, + archivePrefix={arXiv}, + primaryClass={cs.CV}, + url={https://arxiv.org/abs/2512.21331}, +} +""" + +import logging +import math +import os +from collections.abc import Callable, Mapping +from functools import partial +from pathlib import Path +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from jaxtyping import Float +from torch import Tensor +from torch.nn.attention import SDPBackend, sdpa_kernel +from tqdm import tqdm + +# try: +# from torch.amp.autocast_mode import autocast +# except (ImportError, AttributeError): +# try: +# from torch.cuda.amp import autocast +# except ImportError: +# from torch.amp import autocast # type: ignore +from stamp.cache import get_processing_code_hash +from stamp.encoding.encoder import Encoder, EncoderName +from stamp.modeling.data import CoordsInfo +from stamp.preprocessing.config import ExtractorName +from stamp.types import DeviceLikeType + +_logger = logging.getLogger("stamp") + +# Mapping: ExtractorName -> (ticon_key, embedding_dim) +TILE_EXTRACTOR_TO_TICON: dict[ExtractorName, tuple[ExtractorName, int]] = { + ExtractorName.CONCH1_5: (ExtractorName.CONCH1_5, 768), + ExtractorName.H_OPTIMUS_1: (ExtractorName.H_OPTIMUS_1, 1536), + ExtractorName.UNI2: (ExtractorName.UNI2, 1536), + ExtractorName.GIGAPATH: (ExtractorName.GIGAPATH, 1536), + ExtractorName.VIRCHOW2: (ExtractorName.VIRCHOW2, 1280), +} + +# TICON model configuration +TICON_MODEL_CFG: dict[str, Any] = { + "transformers_kwargs": { + "embed_dim": 1536, + "drop_path_rate": 0.0, + "block_kwargs": { + "attn_kwargs": {"num_heads": 24}, + }, + }, + "encoder_kwargs": {"depth": 6}, + "decoder_kwargs": {"depth": 1}, + "in_dims": [768, 1536, 1536, 1536, 1280], + "tile_encoder_keys": [ + ExtractorName.CONCH1_5, + ExtractorName.H_OPTIMUS_1, + ExtractorName.UNI2, + ExtractorName.GIGAPATH, + ExtractorName.VIRCHOW2, + ], + "num_decoders": 1, + "decoder_out_dims": [768, 1536, 1536, 1536, 1280], +} + + +def get_ticon_key(extractor: ExtractorName) -> tuple[ExtractorName, int]: + """Get TICON key and embedding dimension for a given tile extractor.""" + if extractor not in TILE_EXTRACTOR_TO_TICON: + raise ValueError( + f"No TICON mapping for extractor {extractor}. 
" + f"Supported: {list(TILE_EXTRACTOR_TO_TICON.keys())}" + ) + return TILE_EXTRACTOR_TO_TICON[extractor] + + +def get_slopes(n: int) -> list[float]: + """Get ALiBi slopes for n attention heads.""" + + def get_slopes_power_of_2(n: int) -> list[float]: + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + +def scaled_dot_product_attention_alibi( + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Tensor, + dropout_p: float = 0.0, + training: bool = False, +) -> Tensor: + # try Flash Attention with ALiBi first + try: + with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]): + return torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_bias, + dropout_p=dropout_p if training else 0.0, + is_causal=False, + ) + except Exception: + pass + + scale_factor = 1 / math.sqrt(query.size(-1)) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight = attn_weight + attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + + if dropout_p > 0.0: + attn_weight = torch.dropout(attn_weight, dropout_p, train=training) + + return attn_weight @ value + + +## TICON BACKBONE COMPONENTS +class Mlp(nn.Module): + """MLP with SwiGLU activation (used in TICON transformer blocks).""" + + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + mlp_ratio: float = 16 / 3, + bias: bool = True, + ) -> None: + super().__init__() + if hidden_features is None: + hidden_features = int(in_features * mlp_ratio) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features // 2, in_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x1, x2 = x.chunk(2, dim=-1) + x = self.act(x1) * x2 + return self.fc2(x) + + +class ProjectionMlp(nn.Module): + """Projection MLP for input/output transformations with LayerNorm.""" + + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int, + bias: bool = True, + ) -> None: + super().__init__() + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = nn.SiLU() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.norm = nn.LayerNorm(out_features) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return self.norm(x) + + +class Attention(nn.Module): + """Multi-head attention with ALiBi spatial bias for TICON.""" + + def __init__( + self, + dim: int, + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + context_dim: int | None = None, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + context_dim = context_dim or dim + + self.q_proj = nn.Linear(dim, dim, bias=qkv_bias) + self.k_proj = nn.Linear(context_dim, dim, bias=qkv_bias) + self.v_proj = nn.Linear(context_dim, dim, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + # ALiBi slopes (registered as buffer for proper device handling) + slopes = torch.tensor(get_slopes(num_heads), dtype=torch.float32) + self.register_buffer("slopes", slopes[None, :, None, None]) + + def forward( + self, + x: Float[Tensor, "b n_q d"], + coords: 
Float[Tensor, "b n_q 2"], + context: Float[Tensor, "b n_k d_k"] | None = None, + context_coords: Float[Tensor, "b n_k 2"] | None = None, + ) -> Float[Tensor, "b n_q d"]: + if context is None: + context = x + context_coords = coords + + b, n_q, d = x.shape + n_k = context.shape[1] + h = self.num_heads + + # Project queries, keys, values + q = self.q_proj(x).reshape(b, n_q, h, d // h).transpose(1, 2) + k = self.k_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) + v = self.v_proj(context).reshape(b, n_k, h, d // h).transpose(1, 2) + + # Validate coordinates are available + if coords is None or context_coords is None: + raise ValueError( + "Coordinates must be provided for spatial attention with ALiBi bias" + ) + # Compute spatial distances for ALiBi + coords_exp = coords.unsqueeze(2).expand(-1, -1, n_k, -1) + ctx_coords_exp = context_coords.unsqueeze(1).expand(-1, n_q, -1, -1) + euclid_dist = torch.sqrt(torch.sum((coords_exp - ctx_coords_exp) ** 2, dim=-1)) + + # Apply ALiBi bias + attn_bias = -self.slopes * euclid_dist[:, None, :, :] + + # Attention with ALiBi + x = scaled_dot_product_attention_alibi( + q, + k, + v, + attn_bias=attn_bias, + training=self.training, + ) + + x = x.transpose(1, 2).reshape(b, n_q, d) + return self.proj(x) + + +class ResidualBlock(nn.Module): + """Residual connection with optional layer scale and stochastic depth.""" + + def __init__( + self, + drop_prob: float, + norm: nn.Module, + fn: nn.Module, + gamma: nn.Parameter | None, + ): + super().__init__() + self.norm = norm + self.fn = fn + self.keep_prob = 1 - drop_prob + self.gamma = gamma + + def forward(self, x: Tensor, **kwargs) -> Tensor: + fn_out = self.fn(self.norm(x), **kwargs) + + if self.gamma is not None: + fn_out = self.gamma * fn_out + + if self.keep_prob == 1.0 or not self.training: + return x + fn_out + + # Stochastic depth + mask = fn_out.new_empty(x.shape[0]).bernoulli_(self.keep_prob)[:, None, None] + return x + fn_out * mask / self.keep_prob + + +class Block(nn.Module): + """Transformer block with attention and MLP.""" + + def __init__( + self, + dim: int, + drop_path: float, + norm_layer: Callable[[int], nn.Module], + context_dim: int | None, + layer_scale: bool = True, + attn_kwargs: Mapping = {}, + ) -> None: + super().__init__() + + gamma1 = nn.Parameter(torch.ones(dim)) if layer_scale else None + gamma2 = nn.Parameter(torch.ones(dim)) if layer_scale else None + + self.residual1 = ResidualBlock( + drop_path, + norm_layer(dim), + Attention(dim, context_dim=context_dim, **attn_kwargs), + gamma1, + ) + self.residual2 = ResidualBlock( + drop_path, + norm_layer(dim), + Mlp(in_features=dim), + gamma2, + ) + + def forward( + self, + x: Tensor, + coords: Tensor, + context: Tensor | None = None, + context_coords: Tensor | None = None, + ) -> Tensor: + x = self.residual1( + x, + context=context, + coords=coords, + context_coords=context_coords, + ) + x = self.residual2(x) + return x + + +class Transformer(nn.Module): + """Transformer encoder/decoder stack for TICON.""" + + def __init__( + self, + embed_dim: int, + norm_layer: Callable[[int], nn.Module], + depth: int, + drop_path_rate: float, + context_dim: int | None = None, + block_kwargs: Mapping[str, Any] = {}, + ): + super().__init__() + self.embed_dim = embed_dim + self.n_blocks = depth + + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + drop_path=drop_path_rate, + norm_layer=norm_layer, + context_dim=context_dim, + **block_kwargs, + ) + for _ in range(depth) + ] + ) + + def forward( + self, + x: Tensor, + coords: Tensor, + 
return_layers: set[int], + contexts: list[Tensor] | None = None, + context_coords: Tensor | None = None, + ) -> dict[int, Tensor]: + outputs = {} + if 0 in return_layers: + outputs[0] = x + + for blk_idx, blk in enumerate(self.blocks): + context = contexts[blk_idx] if contexts is not None else None + x = blk( + x, + coords=coords, + context=context, + context_coords=context_coords, + ) + if blk_idx + 1 in return_layers: + outputs[blk_idx + 1] = x + + return outputs + + +class TiconBackbone(nn.Module): + """ + TICON Encoder-Decoder backbone. + + This is the core TICON model that contextualizes tile embeddings + using spatial attention with ALiBi positional bias. + """ + + def __init__( + self, + in_dims: list[int], + tile_encoder_keys: list[str], + transformers_kwargs: Mapping[str, Any], + encoder_kwargs: Mapping[str, Any], + decoder_kwargs: Mapping[str, Any] = {}, + norm_layer_type: str = "LayerNorm", + norm_layer_kwargs: Mapping[str, Any] = {"eps": 1e-5}, + final_norm_kwargs: Mapping[str, Any] = {"elementwise_affine": True}, + out_layer: int = -1, + num_decoders: int = 0, + decoder_out_dims: list[int] = [], + **kwargs, # Ignore extra kwargs like patch_size + ): + super().__init__() + + norm_layer: Callable[[int], nn.Module] = partial( + getattr(nn, norm_layer_type), **norm_layer_kwargs + ) + + self.encoder = Transformer( + **transformers_kwargs, + **encoder_kwargs, + norm_layer=norm_layer, + ) + + self.tile_encoder_keys = tile_encoder_keys + self.embed_dim = self.encoder.embed_dim + self.out_layer = out_layer % (len(self.encoder.blocks) + 1) + self.enc_norm = norm_layer(self.embed_dim, **final_norm_kwargs) + + # Input projections for each tile encoder + self.input_proj_dict = nn.ModuleDict( + { + f"input_proj_{key}": ProjectionMlp( + in_features=in_dims[i], + hidden_features=self.embed_dim, + out_features=self.embed_dim, + ) + for i, key in enumerate(tile_encoder_keys) + } + ) + + def init_weights(self) -> "TiconBackbone": + """Initialize model weights.""" + self.apply(_init_weights) + return self + + def forward( + self, + x: Float[Tensor, "b n d"], + relative_coords: Float[Tensor, "b n 2"], + tile_encoder_key: str, + ) -> Float[Tensor, "b n d"]: + """ + Forward pass through TICON encoder. 
+ + Args: + x: Tile embeddings [B, N, D] + relative_coords: Tile coordinates [B, N, 2] + tile_encoder_key: Which input projection to use + + Returns: + Contextualized embeddings [B, N, embed_dim] + """ + # Project input to TICON embedding dimension + x = self.input_proj_dict[f"input_proj_{tile_encoder_key}"](x) + + # Run through transformer encoder + encoder_outputs = self.encoder( + x, + coords=relative_coords, + return_layers={self.out_layer}, + ) + + # Apply final normalization + return self.enc_norm(encoder_outputs[self.out_layer]) + + +def _init_weights(m: nn.Module) -> None: + """Initialize model weights following JAX ViT convention.""" + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) and m.elementwise_affine: + nn.init.constant_(m.weight, 1.0) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +def load_ticon_backbone( + device: DeviceLikeType = "cuda", + model_cfg: dict | None = None, +) -> TiconBackbone: + """Load pretrained TICON backbone from HuggingFace.""" + model_cfg = TICON_MODEL_CFG if model_cfg is None else model_cfg + + # Download checkpoint from HuggingFace + ckpt_path = hf_hub_download( + repo_id="varunb/TICON", + filename="backbone/checkpoint.pth", + repo_type="model", + ) + + # Create model on meta device (no memory allocation) + with torch.device("meta"): + model = TiconBackbone(**model_cfg) + + # Move to target device and initialize weights + model.to_empty(device=device) + model.init_weights() + + # Load pretrained weights + state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) + state_dict = { + k.removeprefix("backbone."): v + for k, v in state_dict.items() + if k.startswith("backbone.") + } + + model.load_state_dict(state_dict, strict=False) + model.eval() + + return model + + +## TICON BACKBONE END ## + + +## TICON ENCODER CLASS ## +class TiconEncoder(Encoder): + def __init__( + self, + device: DeviceLikeType = "cuda", + precision: torch.dtype = torch.float32, + ): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + ticon_model = load_ticon_backbone(device=device) + + super().__init__( + model=ticon_model, + identifier=EncoderName.TICON, + precision=precision, + required_extractors=list(TILE_EXTRACTOR_TO_TICON.keys()), + ) + + self._device = torch.device(device) + self._current_extractor = None + + def _prepare_coords(self, coords: CoordsInfo, num_tiles: int) -> Tensor: + """Prepare coordinates tensor for TICON.""" + if coords is None: + print("No coords provided, using zeros.") + return torch.zeros( + 1, num_tiles, 2, device=self._device, dtype=torch.float32 + ) + # CoordsInfo: get relative positions + if isinstance(coords, CoordsInfo): + coords_data = coords.coords_um + if coords.tile_size_um and coords.tile_size_um > 0: + # converting to grid-indices to get relative positions (is optional only, can be left out) + coords_data = coords.coords_um / coords.tile_size_um + else: + coords_data = coords.coords_um + else: + coords_data = coords + + # convert CoordsInfo to tensor + if not isinstance(coords_data, torch.Tensor): + coords_data = np.array(coords_data) + coords_tensor = torch.from_numpy(coords_data) + else: + coords_tensor = coords_data + + # adapt dimensions (add batch dim) + if coords_tensor.dim() == 2: + coords_tensor = coords_tensor.unsqueeze(0) # [1, N, 2] + assert ( + coords_tensor.shape[1] == num_tiles + ) # number of coords-pairs must match number of tiles + return 
coords_tensor.to(self._device, dtype=torch.float32) + + def _generate_slide_embedding( + self, + feats: torch.Tensor, + device: DeviceLikeType, + **kwargs, + ) -> np.ndarray: + """Generate contextualized slide embedding using TICON.""" + + # get extractor from kwargs + extractor = kwargs.get("extractor") + if extractor is None: + raise ValueError("extractor must be provided for TICON encoding") + + # Normalize plain strings to ExtractorName + if isinstance(extractor, str): + extractor = ExtractorName(extractor) + + tile_encoder_key, _ = get_ticon_key(extractor) + _logger.debug(f"Using tile extractor {tile_encoder_key} for TICON") + if feats.dim() == 2: + feats = feats.unsqueeze(0) # add batch dim + feats = feats.to(self._device, dtype=torch.float32) + + # get coords from kwargs + coords_tensor = kwargs.get("coords", None) + _logger.debug( + f"Coords tensor shape: {coords_tensor.shape}" + if coords_tensor is not None + else "No coords tensor provided" + ) + # # check pytorch version for autocast compatibility + # is_legacy_autocast = "torch.cuda.amp" in autocast.__module__ + + # ac_kwargs = { + # "enabled": (self._device.type == "cuda"), + # "dtype": torch.bfloat16, + # } + # # if it is the new version: add device_type + # if not is_legacy_autocast: + # ac_kwargs["device_type"] = "cuda" + + # Inference only, without autocast + with torch.no_grad(): + try: + contextualized = self.model( + x=feats, + relative_coords=coords_tensor, + tile_encoder_key=tile_encoder_key, + ) + except RuntimeError as e: + _logger.error( + f"RuntimeError during TICON encoding: {e}" + ) + raise + + # try: + # with autocast(**ac_kwargs): + # contextualized = self.model( + # x=feats, + # relative_coords=coords_tensor, + # tile_encoder_key=tile_encoder_key, + # ) + # except RuntimeError as e: + # _logger.error( + # f"RuntimeError during TICON encoding with autocast {ac_kwargs}: {e}. Retrying without autocast."
+ # ) + # contextualized = self.model( + # x=feats, + # relative_coords=coords_tensor, + # tile_encoder_key=tile_encoder_key, + # ) + + return contextualized.detach().squeeze(0).cpu().numpy() + + # Minimal patient-level implementation so TiconEncoder satisfies the Encoder interface: + # concatenate the contextualized features of the patient's slides. + def _generate_patient_embedding( + self, + feats_list: list[torch.Tensor], + device: DeviceLikeType, + **kwargs, + ) -> np.ndarray: + contextualized = [ + self._generate_slide_embedding(feats, device, **kwargs) + for feats in feats_list + ] + return np.concatenate(contextualized, axis=0) + + def encode_slides_( + self, + output_dir: Path, + feat_dir: Path, + device: DeviceLikeType, + generate_hash: bool = True, + **kwargs, + ) -> None: + if generate_hash: + encode_dir = f"{self.identifier}-slide-{get_processing_code_hash(Path(__file__))[:8]}" + else: + encode_dir = f"{self.identifier}-slide" + + encode_dir = output_dir / encode_dir + os.makedirs(encode_dir, exist_ok=True) + + self.model.to(device).eval() + + h5_files = [f for f in os.listdir(feat_dir) if f.endswith(".h5")] + + for filename in (progress := tqdm(h5_files)): + h5_path = os.path.join(feat_dir, filename) + slide_name = Path(filename).name + progress.set_description(slide_name) + + output_path = (encode_dir / slide_name).with_suffix(".h5") + if output_path.exists(): + _logger.info(f"Skipping {slide_name}: output exists") + continue + + # Validate that the features come from a supported tile extractor, + # then re-read the file to also obtain the extractor name. + try: + feats, coords = self._validate_and_read_features(h5_path) + except ValueError as e: + tqdm.write(s=str(e)) + continue + try: + feats, coords, extractor = self._read_h5(h5_path) + except ValueError as e: + tqdm.write(str(e)) + continue + try: + target_extractor = ExtractorName(extractor) # str → Enum + except ValueError: + target_extractor = extractor # already an ExtractorName + + # TICON yields contextualized per-tile features rather than a single slide vector, + # so keep the tile coords and tile-size metadata and store them with feat_type "tile". + coords_um_np = coords.coords_um + _logger.debug( + f"Coords um shape: {coords_um_np.shape}" + if coords is not None + else "No coords found" + ) + + # CoordsInfo -> absolute coords in µm + if isinstance(coords_um_np, torch.Tensor): + coords_um_np = coords_um_np.detach().cpu().numpy() + else: + coords_um_np = np.asarray(coords_um_np) + _logger.debug(f"Coords as numpy array, shape: {coords_um_np.shape}") + + slide_embedding = self._generate_slide_embedding( + feats, + device, + coords=self._prepare_coords(coords, feats.shape[0]), + extractor=target_extractor, + ) + + self._save_features_( + output_path=output_path, + feats=slide_embedding, + feat_type="tile", + coords=coords_um_np, + tile_size_um=float(coords.tile_size_um) + if coords.tile_size_um is not None + else None, + tile_size_px=int(coords.tile_size_px) + if coords.tile_size_px is not None + else None, + unit="um", + ) diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index a1844526..ebcf3a03 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -122,6 +122,7 @@ def extract_( cache_dir: Path | None, cache_tiles_ext: ImageExtension, extractor: ExtractorName | Extractor, + tile_extractor: ExtractorName, tile_size_px: TilePixels, tile_size_um: Microns, max_workers: int, @@ -222,6 +223,11 @@ def extract_( extractor = plip() + case ExtractorName.TICON: + from stamp.preprocessing.extractor.ticon import ticon + + extractor = ticon(tile_extractor=tile_extractor) + case ExtractorName.EMPTY: from stamp.preprocessing.extractor.empty import empty extractor = empty() @@ -238,7 +244,8 @@ def extract_( code_hash =
get_processing_code_hash(Path(__file__))[:8] extractor_id = extractor.identifier - + if extractor_id == ExtractorName.TICON and tile_extractor is not None: + extractor_id = f"{extractor_id}-{tile_extractor}" _logger.info(f"Using extractor {extractor.identifier}") if cache_dir: @@ -330,6 +337,8 @@ def extract_( h5_fp.attrs["stamp_version"] = stamp.__version__ h5_fp.attrs["extractor"] = str(extractor.identifier) + if tile_extractor is not None: + h5_fp.attrs["tile_extractor"] = str(tile_extractor) h5_fp.attrs["unit"] = "um" h5_fp.attrs["tile_size_um"] = tile_size_um # changed in v2.1.0 h5_fp.attrs["tile_size_px"] = tile_size_px diff --git a/src/stamp/preprocessing/config.py b/src/stamp/preprocessing/config.py index 244d70dd..b8595ae6 100644 --- a/src/stamp/preprocessing/config.py +++ b/src/stamp/preprocessing/config.py @@ -28,6 +28,7 @@ class ExtractorName(StrEnum): MUSK = "musk" MSTAR = "mstar" PLIP = "plip" + TICON = "ticon" EMPTY = "empty" @@ -44,6 +45,7 @@ class PreprocessingConfig(BaseModel, arbitrary_types_allowed=True): tile_size_um: Microns = Microns(256.0) tile_size_px: TilePixels = TilePixels(224) extractor: ExtractorName + tile_extractor: ExtractorName max_workers: int = 8 device: str = "cuda" if torch.cuda.is_available() else "cpu" generate_hash: bool = True diff --git a/src/stamp/preprocessing/extractor/ticon.py b/src/stamp/preprocessing/extractor/ticon.py new file mode 100644 index 00000000..411c722d --- /dev/null +++ b/src/stamp/preprocessing/extractor/ticon.py @@ -0,0 +1,231 @@ +from typing import Callable, cast + +try: + import timm + import torch + import torch.nn as nn + from PIL import Image + from timm.data.config import resolve_data_config + from timm.data.transforms_factory import create_transform + from timm.layers.mlp import SwiGLUPacked + from torchvision import transforms +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "TICON dependencies not installed. 
" + "Please reinstall stamp using `pip install 'stamp[ticon]'`" + ) from e + +from stamp.encoding.encoder.ticon_encoder import ( + TILE_EXTRACTOR_TO_TICON, + get_ticon_key, + load_ticon_backbone, +) +from stamp.preprocessing.config import ExtractorName +from stamp.preprocessing.extractor import Extractor + + +class _Virchow2ClsOnly(nn.Module): + """Wrapper for Virchow2 to return only CLS token.""" + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.model = model + + def forward(self, batch: torch.Tensor) -> torch.Tensor: + return self.model(batch)[:, 0] + + +def _create_tile_encoder( + extractor: ExtractorName, +) -> tuple[nn.Module, Callable[[Image.Image], torch.Tensor]]: + """Create tile encoder model and transform for given extractor.""" + if extractor == ExtractorName.H_OPTIMUS_1: + model = timm.create_model( + "hf-hub:bioptimus/H-optimus-1", + pretrained=True, + init_values=1e-5, + dynamic_img_size=False, + ) + transform = transforms.Compose( + [ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.707223, 0.578729, 0.703617), + std=(0.211883, 0.230117, 0.177517), + ), + ] + ) + return model, transform + + elif extractor == ExtractorName.GIGAPATH: + model = timm.create_model( + "hf_hub:prov-gigapath/prov-gigapath", + pretrained=True, + init_values=1e-5, + dynamic_img_size=False, + ) + transform = transforms.Compose( + [ + transforms.Resize(224), + transforms.ToTensor(), + transforms.Normalize( + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + ), + ] + ) + return model, transform + + elif extractor == ExtractorName.UNI2: + timm_kwargs = { + "img_size": 224, + "patch_size": 14, + "depth": 24, + "num_heads": 24, + "init_values": 1e-5, + "embed_dim": 1536, + "mlp_ratio": 2.66667 * 2, + "num_classes": 0, + "no_embed_class": True, + "mlp_layer": SwiGLUPacked, + "act_layer": torch.nn.SiLU, + "reg_tokens": 8, + "dynamic_img_size": True, + } + model = timm.create_model( + "hf-hub:MahmoodLab/UNI2-h", + pretrained=True, + **timm_kwargs, + ) + transform = cast( + Callable[[Image.Image], torch.Tensor], + create_transform(**resolve_data_config(model.pretrained_cfg, model=model)), + ) + return model, transform + + elif extractor == ExtractorName.VIRCHOW2: + base_model = timm.create_model( + "hf-hub:paige-ai/Virchow2", + pretrained=True, + mlp_layer=SwiGLUPacked, + act_layer=torch.nn.SiLU, + ) + model = _Virchow2ClsOnly(base_model) + transform = cast( + Callable[[Image.Image], torch.Tensor], + create_transform( + **resolve_data_config(base_model.pretrained_cfg, model=base_model) + ), + ) + return model, transform + + elif extractor == ExtractorName.CONCH1_5: + try: + from transformers import AutoModel + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "CONCH v1.5 dependencies not installed. " + "Please reinstall stamp using `pip install 'stamp[conch1_5]'`" + ) from e + + titan = AutoModel.from_pretrained("MahmoodLab/TITAN", trust_remote_code=True) + model, transform = titan.return_conch() + return model, transform + + else: + raise ValueError( + f"Unsupported tile extractor for TICON: {extractor}. 
" + f"Supported: {list(TILE_EXTRACTOR_TO_TICON.keys())}" + ) + + +### TICON Isolated Mode Extractor ### +class TICON(nn.Module): + """TICON in Isolated Mode - processes each tile independently.""" + + def __init__( + self, + tile_extractor: ExtractorName, + device: str = "cuda", + ): + super().__init__() + self._device = torch.device(device) + self.tile_extractor = tile_extractor + + # Validate extractor is supported by TICON + if tile_extractor not in TILE_EXTRACTOR_TO_TICON: + raise ValueError( + f"Tile extractor {tile_extractor} is not supported by TICON. " + f"Supported: {list(TILE_EXTRACTOR_TO_TICON.keys())}" + ) + + # Get TICON key and embedding dimension + self.tile_encoder_key, self.embed_dim = get_ticon_key(tile_extractor) + + # Stage 1: Create tile encoder + self.tile_encoder, self._transform = _create_tile_encoder(tile_extractor) + + # Stage 2: Load TICON backbone + self.ticon = load_ticon_backbone(device=device) + + self.to(self._device) + self.eval() + + def get_transform(self) -> Callable[[Image.Image], torch.Tensor]: + """Get image transform for this tile extractor.""" + return self._transform + + @torch.inference_mode() + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through TICON Isolated Mode.""" + x = x.to(self._device, non_blocking=True) + + # Stage 1: Extract tile features + + # with torch.amp.autocast( + # device_type="cuda", + # dtype=torch.bfloat16, + # enabled=(self._device.type == "cuda"), + # ): + emb = self.tile_encoder(x) + + # Handle different output shapes (some models return [B, N, D]) + if emb.dim() == 3: + emb = emb[:, 0] # Take CLS token + + # Add sequence dimension for TICON: [B, D] -> [B, 1, D] + emb = emb.unsqueeze(1) + + # Stage 2: TICON (single tile = no spatial context, use zero coords) + coords = torch.zeros( + emb.size(0), + 1, + 2, + device=self._device, + dtype=torch.float32, + ) + + # with torch.amp.autocast( + # device_type="cuda", + # dtype=torch.bfloat16, + # enabled=(self._device.type == "cuda"), + # ): + out = self.ticon( + x=emb.float(), # TICON expects float32 input + relative_coords=coords, + tile_encoder_key=self.tile_encoder_key, + ) + + # Remove sequence dimension: [B, 1, D] -> [B, D] + return out.squeeze(1) + + +def ticon(tile_extractor: ExtractorName) -> Extractor[TICON]: + """Create TICON Isolated Mode extractor.""" + model = TICON(tile_extractor=tile_extractor) + return Extractor( + model=model, + transform=model.get_transform(), + identifier=ExtractorName.TICON, + )