Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/stamp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _run_cli(args: argparse.Namespace) -> None:
tile_size_um=config.preprocessing.tile_size_um,
tile_size_px=config.preprocessing.tile_size_px,
extractor=config.preprocessing.extractor,
tile_extractor=config.preprocessing.tile_extractor,
max_workers=config.preprocessing.max_workers,
device=config.preprocessing.device,
default_slide_mpp=config.preprocessing.default_slide_mpp,
Expand Down
7 changes: 5 additions & 2 deletions src/stamp/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@ preprocessing:
# Extractor to use for feature extraction. Possible options are "ctranspath",
# "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom",
# "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow",
# "virchow-full", "musk", "mstar", "plip"
# "virchow-full", "musk", "mstar", "plip", "ticon"
# Some of them require requesting access from the respective authors beforehand.
extractor: "chief-ctranspath"

# Device to run feature extraction on ("cpu", "cuda", "cuda:0", etc.)
device: "cuda"

# Optional settings:
# If "ticon" is selected, specify the model to enhance,
# e.g. "h-optimus-1", "virchow2", "conch1_5", "uni2", "gigapath".
tile_extractor: "h-optimus-1"

# Having a cache dir will speed up extracting features multiple times,
# e.g. with different feature extractors. Optional.
Expand Down Expand Up @@ -249,7 +252,7 @@ heatmaps:

slide_encoding:
# Encoder to use for slide encoding. Possible options are "cobra",
# "eagle", "titan", "gigapath", "chief", "prism", "madeleine".
# "eagle", "titan", "gigapath", "chief", "prism", "madeleine", "ticon".
encoder: "chief"

# Directory to save the output files.
Expand Down
10 changes: 10 additions & 0 deletions src/stamp/encoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def init_slide_encoder_(

selected_encoder: Encoder = Prism()

case EncoderName.TICON:
from stamp.encoding.encoder.ticon_encoder import TiconEncoder

selected_encoder: Encoder = TiconEncoder()

case Encoder():
selected_encoder = encoder

Expand Down Expand Up @@ -155,6 +160,11 @@ def init_patient_encoder_(

selected_encoder: Encoder = Prism()

case EncoderName.TICON:
from stamp.encoding.encoder.ticon_encoder import TiconEncoder

selected_encoder: Encoder = TiconEncoder()

case Encoder():
selected_encoder = encoder

Expand Down
1 change: 1 addition & 0 deletions src/stamp/encoding/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class EncoderName(StrEnum):
GIGAPATH = "gigapath"
MADELEINE = "madeleine"
PRISM = "prism"
TICON = "ticon"


class SlideEncodingConfig(BaseModel, arbitrary_types_allowed=True):
Expand Down
23 changes: 16 additions & 7 deletions src/stamp/encoding/encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,12 @@ def encode_slides_(
tqdm.write(s=str(e))
continue

slide_embedding = self._generate_slide_embedding(
feats, device, coords=coords
)
slide_embedding = self._generate_slide_embedding(feats, device, **kwargs)
self._save_features_(
output_path=output_path, feats=slide_embedding, feat_type="slide"
output_path=output_path,
feats=slide_embedding,
feat_type="slide",
**kwargs,
)

def encode_patients_(
Expand Down Expand Up @@ -133,7 +134,7 @@ def encode_patients_(
for _, row in group.iterrows():
slide_filename = row[filename_label]
h5_path = os.path.join(feat_dir, slide_filename)
feats, _ = self._validate_and_read_features(h5_path)
feats, coords = self._validate_and_read_features(h5_path)
feats_list.append(feats)

if not feats_list:
Expand All @@ -149,7 +150,10 @@ def encode_patients_(

@abstractmethod
def _generate_slide_embedding(
self, feats: torch.Tensor, device, **kwargs
self,
feats: torch.Tensor,
device,
**kwargs,
) -> np.ndarray:
"""Generate slide embedding. Must be implemented by subclasses."""
pass
Expand Down Expand Up @@ -193,14 +197,19 @@ def _read_h5(
return feats, coords, _resolve_extractor_name(extractor)

def _save_features_(
self, output_path: Path, feats: np.ndarray, feat_type: str
self, output_path: Path, feats: np.ndarray, feat_type: str, **kwargs
) -> None:
with (
NamedTemporaryFile(dir=output_path.parent, delete=False) as tmp_h5_file,
h5py.File(tmp_h5_file, "w") as f,
):
try:
f["feats"] = feats
f["coords"] = kwargs.get("coords", np.array([]))
if "tile_size_um" in kwargs and kwargs["tile_size_um"] is not None:
f.attrs["tile_size_um"] = float(kwargs["tile_size_um"])
if "tile_size_px" in kwargs and kwargs["tile_size_px"] is not None:
f.attrs["tile_size_px"] = int(kwargs["tile_size_px"])
f.attrs["version"] = stamp.__version__
f.attrs["encoder"] = str(self.identifier)
f.attrs["precision"] = str(self.precision)
Expand Down
Loading
Loading