Skip to content
This repository was archived by the owner on Mar 8, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ethos/commands/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
SofaPredictionDataset,
ICUMortalityDataset,
)
from ethos.datasets.mimic import DrgPredictionDataset, ICUReadmissionDataset
from ethos.datasets.mimic import DrgPredictionDataset, ICUReadmissionDataset, ICUPredictionDataset
from ethos.inference import Test, run_inference
from ethos.tokenize import SpecialToken, Vocabulary
from ethos.utils import load_model_from_checkpoint, load_data, get_logger
Expand Down Expand Up @@ -96,6 +96,9 @@ def infer(
elif test == Test.ICU_READMISSION:
dataset_cls = ICUReadmissionDataset
stoi = [ICU_ADMISSION_STOKEN, DISCHARGE_STOKEN] + stoi
elif test == Test.ICU_PREDICTION:
dataset_cls = ICUPredictionDataset
stoi = [ICU_ADMISSION_STOKEN, DISCHARGE_STOKEN] + stoi
else:
raise ValueError(f"Unknown test: {test}, available")

Expand Down
2 changes: 0 additions & 2 deletions ethos/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@
@option("--device", default="cuda", type=Choice(["cuda", "cpu"]))
@option("--dtype", default="bfloat16", type=Choice(["float32", "bfloat16", "float16"]))
@option("--no_compile", is_flag=True, help="Don't compile the model using Triton.")
# optional
@option("--ctx_no_grad", is_flag=True, help="Don't compute gradient for the context tokens.")
def train(**kwargs):
"""This training script can be run both on a single gpu in debug mode, and also in a larger
training run with distributed data parallel (ddp).
Expand Down
8 changes: 4 additions & 4 deletions ethos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
PROJECT_DATA: Path = files("ethos.data")
PROJECT_ROOT: Path = (PROJECT_DATA / "../..").resolve()

ADMISSION_STOKEN = "INPATIENT_ADMISSION_START"
DISCHARGE_STOKEN = "INPATIENT_ADMISSION_END"
ADMISSION_STOKEN = "INPATIENT//ADMISSION"
DISCHARGE_STOKEN = "INPATIENT//DISCHARGE"
# present only in the MIMIC dataset
ICU_ADMISSION_STOKEN = "ICU_STAY_START"
ICU_DISCHARGE_STOKEN = "ICU_STAY_END"
ICU_ADMISSION_STOKEN = "ICU//ADMISSION"
ICU_DISCHARGE_STOKEN = "ICU//DISCHARGE"
Binary file added ethos/data/mimic_drug_to_atc.csv.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion ethos/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .admission_mortality import AdmissionMortalityDataset, AdmissionMortalityNextTokenDataset
from .base import TimelineDataset
from .mimic import SofaPredictionDataset, ICUMortalityDataset
from .mimic import SofaPredictionDataset, ICUMortalityDataset, ICUPredictionDataset
from .mortality import MortalityDataset, SingleAdmissionMortalityDataset
from .readmission import ReadmissionDataset
4 changes: 2 additions & 2 deletions ethos/datasets/admission_mortality.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __getitem__(self, idx) -> tuple[th.Tensor, dict]:
patient_idx = self._get_patient_idx(admission_idx)
data_start_idx = self.patient_offsets[patient_idx]

if admission_idx - data_start_idx - 1 > self.timeline_len:
if admission_idx - data_start_idx + 1 > self.timeline_len:
data_start_idx = admission_idx + 1 - self.timeline_len

patient_context = self._get_patient_context(data_start_idx)
Expand All @@ -55,7 +55,7 @@ def __getitem__(self, idx) -> tuple[th.Tensor, dict]:
"true_token_time": (self.times[toi_idx] - self.times[admission_idx]).item(),
"patient_id": self.patient_ids[patient_idx].item(),
"patient_age": self.times[data_start_idx].item(),
"admission_token_idx": admission_idx.item(),
"data_idx": admission_idx.item(),
"year": year,
}

Expand Down
11 changes: 9 additions & 2 deletions ethos/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,21 @@ def __init__(self, data: dict, encode: Callable, block_size: int = 2048):
# vocab encode function that translates strings to integers
self.encode: Callable = encode

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(len={len(self):,}, "
f"patient_num={len(set(self.patient_ids)):,})"
)

def __len__(self) -> int:
return len(self.times) - self.timeline_len

def __getitem__(self, idx: int) -> tuple[th.Tensor, th.Tensor]:
patient_context = self._get_patient_context(idx)
timeline = self.tokens[idx : idx + self.timeline_len + 1]
timeline = self.tokens[idx: idx + self.timeline_len + 1]
x = th.cat((patient_context, timeline[:-1]))
y = th.cat((patient_context, timeline[1:]))
y[: self.context_len] = -100
return x, y

def _get_patient_context(self, idx: int) -> th.Tensor:
Expand Down Expand Up @@ -82,7 +89,7 @@ def _get_indices_of_stokens(self, stokens: str | Sequence[str]) -> np.ndarray[np

@staticmethod
def _match_next_value(
to_match: Sequence, match_with: Sequence, always_match: bool = True
to_match: Sequence, match_with: Sequence, always_match: bool = True
) -> np.ndarray[int | float]:
"""
Return the next closest values in `match_with` for every corresponding value in `to_match`.
Expand Down
70 changes: 62 additions & 8 deletions ethos/datasets/mimic.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def __getitem__(self, idx: int) -> tuple[th.Tensor, dict]:
data_start_idx = self.patient_offsets[patient_idx]
# shorten the input timeline if the patient history is too long, -1 because we include
# the admission token
if admission_idx - data_start_idx - 1 > self.timeline_len:
if admission_idx - data_start_idx + 1 > self.timeline_len:
data_start_idx = admission_idx + 1 - self.timeline_len
patient_context = self._get_patient_context(data_start_idx)
timeline = self.tokens[data_start_idx : admission_idx + 1]
timeline = self.tokens[data_start_idx: admission_idx + 1]
x = th.cat((patient_context, timeline))
return x, {
"expected": self.tokens[admission_idx + 3].item(),
Expand Down Expand Up @@ -116,7 +116,7 @@ def _exclude_too_short_stays(self):
times = self.times.numpy()
for i, adm_idx in enumerate(self.admission_indices):
offset = bisect_left(
times[adm_idx : self._get_next_timeline_start(adm_idx)],
times[adm_idx: self._get_next_timeline_start(adm_idx)],
times[adm_idx] + self.TIME_OFFSET,
)
self.admission_indices[i] += offset - 1
Expand Down Expand Up @@ -159,7 +159,7 @@ def __init__(self, data, encode, block_size: int):
)
adm_dc_or_end_indices = self._match_next_value(adm_indices, dc_or_end_indices)
has_icu_stay = (adm_icu_dc_indices < adm_dc_or_end_indices) | (
adm_icu_adm_indices < adm_dc_or_end_indices
adm_icu_adm_indices < adm_dc_or_end_indices
)

# discard cases where a patient dies during the first ICU stay
Expand All @@ -169,8 +169,8 @@ def __init__(self, data, encode, block_size: int):
(
adm_death_idx < adm_icu_dc_idx < adm_dc_or_end_idx
for adm_death_idx, adm_icu_dc_idx, adm_dc_or_end_idx in zip(
adm_death_indices, adm_icu_dc_indices, adm_dc_or_end_indices
)
adm_death_indices, adm_icu_dc_indices, adm_dc_or_end_indices
)
),
dtype=bool,
count=len(adm_indices),
Expand Down Expand Up @@ -199,11 +199,11 @@ def __getitem__(self, idx) -> tuple[th.Tensor, dict]:
patient_idx = self._get_patient_idx(icu_dc_idx)
data_start_idx = self.patient_offsets[patient_idx]

if icu_dc_idx - data_start_idx - 1 > self.timeline_len:
if icu_dc_idx - data_start_idx + 1 > self.timeline_len:
data_start_idx = icu_dc_idx + 1 - self.timeline_len

patient_context = self._get_patient_context(data_start_idx)
timeline = self.tokens[data_start_idx : icu_dc_idx + 1]
timeline = self.tokens[data_start_idx: icu_dc_idx + 1]
x = th.cat((patient_context, timeline))

if self.is_readmitted[idx]:
Expand All @@ -226,3 +226,57 @@ def __getitem__(self, idx) -> tuple[th.Tensor, dict]:
}
)
return x, y


class ICUPredictionDataset(InferenceDataset):
    """Inference dataset: does an inpatient admission lead to an ICU stay?

    One sample per inpatient admission. The input timeline is cut shortly
    after the admission token; the label is whether the first subsequent
    episode-terminating token (discharge, ICU admission, or death) is an
    ICU admission.
    """

    def __init__(self, data, encode, block_size: int):
        super().__init__(data, encode, block_size)
        self.adm_indices = self._get_indices_of_stokens(ADMISSION_STOKEN)
        # Every token position that can terminate an admission episode.
        terminal_indices = self._get_indices_of_stokens(
            [DISCHARGE_STOKEN, ICU_ADMISSION_STOKEN, SpecialToken.DEATH]
        )
        # A trailing admission with no terminating token after it cannot be
        # labeled, so drop it before the next-value match below.
        if self.adm_indices[-1] > terminal_indices[-1]:
            self.adm_indices = self.adm_indices[:-1]
        self.adm_dc_or_icu_adm_indices = self._match_next_value(
            self.adm_indices, terminal_indices
        )
        icu_adm_indices = self._get_indices_of_stokens(ICU_ADMISSION_STOKEN)
        # Positive label iff the matched terminating token is an ICU admission.
        self.ended_up_in_icu = np.isin(self.adm_dc_or_icu_adm_indices, icu_adm_indices)

    def __len__(self) -> int:
        # One sample per (labelable) inpatient admission.
        return len(self.adm_indices)

    def __getitem__(self, idx) -> tuple[th.Tensor, dict]:
        # +2 steps past the two tokens that immediately follow the admission
        # stoken — presumably its associated quantile/type tokens; TODO confirm
        # against the tokenization scheme.
        adm_idx = self.adm_indices[idx] + 2
        patient_idx = self._get_patient_idx(adm_idx)
        data_start_idx = self.patient_offsets[patient_idx]

        # Truncate old history so context + timeline fits within timeline_len.
        if adm_idx + 1 - data_start_idx > self.timeline_len:
            data_start_idx = adm_idx + 1 - self.timeline_len

        x = th.cat(
            (
                self._get_patient_context(data_start_idx),
                self.tokens[data_start_idx : adm_idx + 1],
            )
        )

        if self.ended_up_in_icu[idx]:
            icu_idx = int(self.adm_dc_or_icu_adm_indices[idx])
            y = {
                "expected": 1,
                "true_token_dist": (icu_idx - adm_idx).item(),
                "true_token_time": (self.times[icu_idx] - self.times[adm_idx]).item(),
            }
        else:
            y = {"expected": 0}

        y["patient_id"] = self.patient_ids[patient_idx].item()
        y["patient_age"] = self.times[data_start_idx].item()
        y["data_idx"] = adm_idx.item()
        y["year"] = self._get_year_at_timeline_start(
            patient_idx, self.times[data_start_idx]
        )
        return x, y
2 changes: 1 addition & 1 deletion ethos/datasets/mortality.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __getitem__(self, idx) -> tuple[th.Tensor, dict]:
patient_idx = self._get_patient_idx(self.admission_idx)
data_start_idx = self.patient_offsets[patient_idx]
data_end_idx = self.admission_idx + idx
if data_end_idx - 1 - data_start_idx > self.timeline_len:
if data_end_idx + 1 - data_start_idx > self.timeline_len:
data_start_idx = data_end_idx + 1 - self.timeline_len

timeline = self.tokens[data_start_idx : data_end_idx + 1]
Expand Down
4 changes: 2 additions & 2 deletions ethos/datasets/readmission.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __getitem__(self, idx) -> tuple[th.Tensor, dict]:
y = {"expected": 0}

data_start_idx = self.patient_offsets[patient_idx]
if discharge_idx - 1 - data_start_idx > self.timeline_len:
if discharge_idx + 1 - data_start_idx > self.timeline_len:
data_start_idx = discharge_idx + 1 - self.timeline_len

timeline = self.tokens[data_start_idx : discharge_idx + 1]
Expand All @@ -57,7 +57,7 @@ def __getitem__(self, idx) -> tuple[th.Tensor, dict]:
{
"patient_id": self.patient_ids[patient_idx].item(),
"patient_age": self.times[data_start_idx].item(),
"discharge_token_idx": discharge_idx.item(),
"data_idx": discharge_idx.item(),
}
)
if self.is_mimic:
Expand Down
1 change: 1 addition & 0 deletions ethos/inference/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ class Test(Enum):
SOFA_PREDICTION = "sofa"
ICU_MORTALITY = "icu_mortality"
ICU_READMISSION = "icu_readmission"
ICU_PREDICTION = "icu_prediction"
2 changes: 1 addition & 1 deletion ethos/inference/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def run_inference(loader, args, num_gpus: int = 8):
context_len = dataset.context_len
timeline_len = dataset.timeline_len
max_timeline_size = context_len + timeline_len
time_limit = 30 / 365.25 if test_name == Test.READMISSION else 2
time_limit = 30 / 365.25 if test_name in (Test.READMISSION, Test.ICU_PREDICTION) else 2
toi = th.tensor(vocab.encode(stoi), device=device, dtype=th.long)

results = []
Expand Down
17 changes: 10 additions & 7 deletions ethos/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,20 @@ def print_auc_roc_plot(res, gaussian_res, title="AUC-ROC", lw=2, clinical=False)


def process_readmission_results(
filename: str, admission_stoken: str, readmission_period: float
filename: str, admission_stoken: str, sample_id_col: str = "discharge_token_idx",
readmission_period: float | None = None,
) -> pd.DataFrame:
res_dir = PROJECT_ROOT / "results" / filename
df = pd.concat(pd.read_json(res_path) for res_path in res_dir.iterdir())
df.rename(columns={"actual": "actual_token", "patient_id": "subject_id"}, inplace=True)
df["actual"] = (df.actual_token == admission_stoken).astype(int)
df["expected"] = ((df.expected == 1) & (df.true_token_time <= readmission_period)).astype(int)
discharge_idx_name = (
"discharge_token_idx" if admission_stoken == ADMISSION_STOKEN else "discharge_idx"
)
df_gb = df.groupby(discharge_idx_name, dropna=False)

if readmission_period is not None:
df["expected"] = (
(df.expected == 1) & (df.true_token_time <= readmission_period)
).astype(int)

df_gb = df.groupby(sample_id_col, dropna=False)
return (
df_gb.agg(
{
Expand All @@ -219,7 +222,7 @@ def process_readmission_results(
"token_time": "mean",
"token_dist": "mean",
"patient_age": "first",
discharge_idx_name: "first",
sample_id_col: "first",
}
)
.join(df_gb.agg(count=("actual", "count")))
Expand Down
40 changes: 9 additions & 31 deletions ethos/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,6 @@
from torch.nn import functional as F


class LayerNorm(nn.Module):

def __init__(self, ndim, bias):
super().__init__()
self.weight = nn.Parameter(torch.ones(ndim))
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

def forward(self, input):
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class CausalSelfAttention(nn.Module):
def __init__(self, config, attention_weights: Optional[list] = None):
super().__init__()
Expand Down Expand Up @@ -104,9 +93,9 @@ def forward(self, x):
class Block(nn.Module):
def __init__(self, config, attention_weights: Optional[list] = None):
super().__init__()
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
self.attn = CausalSelfAttention(config, attention_weights=attention_weights)
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
self.mlp = MLP(config)

def forward(self, x):
Expand All @@ -119,7 +108,8 @@ def forward(self, x):
class ModelConfig:
block_size: int = 1024
vocab_size: int = (
50304 # number of tokens in the GPT-2 vocabulary, change to taste if you use a different vocab
50304
# number of tokens in the GPT-2 vocabulary, change to taste if you use a different vocab
)
n_layer: int = 12
n_head: int = 12
Expand Down Expand Up @@ -148,7 +138,7 @@ def __init__(self, config, return_attention=False):
h=nn.ModuleList(
[Block(config, self.attention_weights) for _ in range(config.n_layer)]
),
ln_f=LayerNorm(config.n_embd, bias=config.bias),
ln_f=nn.LayerNorm(config.n_embd, bias=config.bias),
)
)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
Expand Down Expand Up @@ -187,18 +177,17 @@ def _init_weights(self, module):
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

def forward(self, idx, targets=None, context_length=0):
def forward(self, idx, targets=None):
device = idx.device
b, t = idx.size()
assert (
t <= self.config.block_size
t <= self.config.block_size
), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

if self.return_attention:
self.attention_weights.clear()


tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)
Expand All @@ -209,16 +198,7 @@ def forward(self, idx, targets=None, context_length=0):
if targets is not None:
# if we are given some desired targets also calculate the loss
logits = self.lm_head(x)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-1,
reduction="none",
)
if context_length:
loss.view(logits.size()[:2])[:, :context_length] = 0

loss = loss.mean()
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
else:
logits = self.lm_head(x[:, [-1], :])
loss = None
Expand Down Expand Up @@ -270,7 +250,7 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
for _ in range(max_new_tokens):
#
idx_cond = (
idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :]
idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
)
logits, _ = self(idx_cond)
logits = logits[:, -1, :] / temperature
Expand All @@ -285,8 +265,6 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):

@torch.no_grad()
def get_next_token(self, tokens, return_probs=False, top_k=None):
if tokens.size(1) > self.config.block_size:
tokens = tokens[:, -self.config.block_size :]
logits, _ = self(tokens)
logits = logits[:, -1, :]
if top_k is not None:
Expand Down
Loading