Add additional benchmark config
pauldoucet committed Dec 8, 2024
1 parent 755efbb commit 87e586f
Showing 1 changed file with 28 additions and 5 deletions.
src/hest/bench/benchmark.py (28 additions, 5 deletions)
@@ -4,7 +4,7 @@
 import json
 import os
 from operator import itemgetter
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 from dataclasses import dataclass, asdict, field
 from argparse import Namespace
 
@@ -75,14 +75,27 @@ class BenchmarkConfig:
     """
     seed: int = 1
     overwrite: bool = False
+
     bench_data_root: Optional[str] = 'eval/bench_data'
+    # Benchmark data will automatically be downloaded to this path
+
     embed_dataroot: Optional[str] = 'eval/ST_data_emb'
+    # Embeddings generated during benchmarking will be saved to this path
+
     weights_root: Optional[str] = 'eval/fm_v1'
+    # Path to patch encoder weights
+
     results_dir: Optional[str] = 'eval/ST_pred_results'
-    private_weights_root: Optional[str] = None
-    exp_code: Optional[str] = None
+    # Path to benchmark results
+
     batch_size: int = 128
+    # Batch size used during embedding extraction
+
     num_workers: int = 1
+    # Number of workers used during embedding extraction
+
+    private_weights_root: Optional[str] = None
+    exp_code: Optional[str] = None
     gene_list: str = 'var_50genes.json'
     method: str = 'ridge'
     alpha: Optional[float] = None
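
As an aside from the diff (not part of the commit): a minimal sketch of how the newly documented fields can be overridden when building the config, using the same Namespace(**asdict(...)) conversion that the benchmark() entry point below applies. The import path is assumed from the file location src/hest/bench/benchmark.py.

from argparse import Namespace
from dataclasses import asdict

# Import path assumed from src/hest/bench/benchmark.py; adjust if the package layout differs.
from hest.bench.benchmark import BenchmarkConfig

# Override the fields documented above; everything else keeps its default value.
cfg = BenchmarkConfig(
    bench_data_root='eval/bench_data',  # benchmark data is downloaded here
    batch_size=64,                      # batch size used during embedding extraction
    num_workers=4,                      # workers used during embedding extraction
)
args = Namespace(**asdict(cfg))         # same conversion used by benchmark() below
print(args.batch_size, args.num_workers)
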
@@ -376,8 +389,18 @@ def set_seed(seed):
     random.seed(seed)
 
 
-def benchmark(encoder, enc_transf, precision, cli_args=None, **kwargs) -> Tuple[list, dict]:
-
+def benchmark(encoder: torch.nn.Module, enc_transf: Callable, precision: torch.dtype, cli_args: dict=None, **kwargs) -> Tuple[list, dict]:
+    """ Benchmark a patch encoder on HEST-bench
+
+    Args:
+        encoder (torch.nn.Module): patch encoder to benchmark
+        enc_transf (Callable): transformation applied to `encoder` during inference
+        precision (torch.dtype): precision used by torch.cuda.amp.autocast() during inference for `encoder`
+        cli_args (dict): cli_arguments. Defaults to None.
+        **kwargs: lookup `BenchmarkConfig` for additional parameters
+    """
+
+    # get default args - overwritten if using CLI, kwargs, or config file
     args = Namespace(**asdict(BenchmarkConfig()))
 
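
For completeness, a sketch of invoking the re-typed benchmark() entry point. The encoder, transform, normalization values, and the unpacked return names are illustrative placeholders; only the argument interface (a torch.nn.Module encoder, a Callable transform, a torch.dtype precision, and **kwargs looked up in BenchmarkConfig) comes from the commit.

import torch
from torchvision import transforms

# Import path assumed from src/hest/bench/benchmark.py.
from hest.bench.benchmark import benchmark


class DummyEncoder(torch.nn.Module):
    """Placeholder patch encoder: average-pools RGB patches into 3-dim embeddings."""
    def forward(self, x):
        return x.mean(dim=(2, 3))


enc_transf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # placeholder normalization
])

# The names 'results' and 'summary' for the Tuple[list, dict] return are illustrative.
results, summary = benchmark(
    DummyEncoder(),
    enc_transf,
    torch.float16,        # precision handed to torch.cuda.amp.autocast() during inference
    batch_size=64,        # **kwargs are looked up in BenchmarkConfig
    num_workers=2,
)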
