Skip to content

Commit a57796e

Browse files
committed
Add observation matrix form transformer
1 parent c400300 commit a57796e

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

decipher/processing/transformers.py

+41-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import logging
2+
from datetime import timedelta
23
from pathlib import Path
3-
from typing import Any
4+
from typing import Any, Callable
45

56
import numpy as np
7+
import numpy.typing as npt
68
import pandas as pd
79
from sklearn.base import BaseEstimator, TransformerMixin
810

@@ -244,3 +246,41 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
244246
)
245247
.dropna(subset="value")
246248
).astype({"variable": "category", "value": "category"})
249+
250+
251+
class ObservationMatrix(BaseEstimator, TransformerMixin):
    """Convert an exams DataFrame into a (row, time-bin, risk) observation table.

    Each person (``PID``) is mapped to a matrix row index, ages are
    discretized into fixed-width time bins, and all risks falling in the
    same (row, bin) cell are aggregated with ``risk_agg_method``.
    """

    def __init__(
        self, risk_agg_method: str | Callable = "max", months_per_bin: float = 3
    ):
        # risk_agg_method: aggregation handed to DataFrameGroupBy.agg,
        #   e.g. "max", "mean", or a callable.
        # months_per_bin: width of each age bin, in (30-day) months.
        self.risk_agg_method = risk_agg_method
        self.months_per_bin = months_per_bin
        super().__init__()

    def fit(self, X: pd.DataFrame, y=None):
        """Learn the PID -> row mapping and the age-bin edges from ``X``.

        ``X`` is expected to have columns ``PID`` (int64), ``age``
        (timedelta64[ns]) and ``risk`` (Int64).
        """
        # Run the dtype validation. The original code constructed the
        # validator but discarded it, so no check was actually applied.
        # NOTE(review): assumes CleanData validates in fit() like other
        # transformers in this module — confirm its interface.
        CleanData(
            dtypes={"PID": "int64", "age": "timedelta64[ns]", "risk": "Int64"}
        ).fit(X)

        # Create a mapping between row in matrix and PID
        pids = X["PID"].unique()
        self.pid_to_row = {pid: i for i, pid in enumerate(pids)}

        # Make the time bins: equally spaced edges covering the age range.
        days_per_month = 30  # fixed-length months keep bin widths uniform
        bin_width = timedelta(days=self.months_per_bin * days_per_month)
        self.bins: npt.NDArray = np.arange(
            X["age"].min(),
            X["age"].max() + bin_width,  # Add to ensure endpoint is included
            bin_width,
        )
        return self

    def transform(self, X: pd.DataFrame) -> pd.DataFrame:
        """Return a DataFrame with columns ``row``, ``bin`` and aggregated ``risk``."""
        # .copy() avoids pandas' SettingWithCopyWarning: the original code
        # assigned new columns onto a slice of X (chained assignment).
        out = X[["risk"]].copy()
        # Deliberately keeps the KeyError on a PID unseen during fit
        # (a .map() would silently yield NaN instead).
        out["row"] = X["PID"].apply(lambda pid: self.pid_to_row[pid])
        out["bin"] = pd.cut(
            X["age"], self.bins, right=False
        )  # type: ignore[call-overload] # right=False indicates closed left side
        # NOTE(review): "bin" is categorical, so groupby's default
        # observed=False may emit a row per (row, bin) combination incl.
        # empty bins on newer pandas — confirm this is intended.
        return out.groupby(["row", "bin"], as_index=False)["risk"].agg(
            self.risk_agg_method
        )  # type: ignore[return-value] # as_index=False makes this a DataFrame

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "decipher"
3-
version = "0.1.9"
3+
version = "0.1.10"
44
description = "Utilities for Decipher"
55
authors = ["Thorvald Molthe Ballestad <[email protected]>"]
66
readme = "README.md"

tests/test_processing.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
read_raw_df,
1414
write_to_csv,
1515
)
16-
from decipher.processing.transformers import HPVResults, PersonStats
16+
from decipher.processing.transformers import HPVResults, ObservationMatrix, PersonStats
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -52,6 +52,25 @@ def test_read_and_hpv_pipeline():
5252
logger.debug(hpv_df)
5353

5454

55+
def test_observation_out():
    """ObservationMatrix must emit exactly one aggregated risk per (row, bin)."""
    exam_pipeline = get_exam_pipeline(
        birthday_file=test_data_dob, drop_missing_birthday=True
    )
    exam_df = exam_pipeline.fit_transform(read_raw_df(test_data_screening))

    observations = ObservationMatrix().fit_transform(exam_df)
    logger.info(observations)

    assert set(observations) == {"bin", "row", "risk"}
    # Each (row, bin) cell may only hold a single aggregated risk value
    cell_counts = observations.value_counts(subset=["row", "bin"])
    assert cell_counts.unique() == [1]

    # The bin edges must span the full observed age range
    intervals = observations["bin"].cat.categories
    assert intervals[0].left <= exam_df["age"].min()
    assert intervals[-1].right > exam_df["age"].max()
73+
5574
def test_person_stats():
5675
raw = read_raw_df(test_data_screening)
5776

0 commit comments

Comments
 (0)