|
1 | 1 | import logging
|
| 2 | +from datetime import timedelta |
2 | 3 | from pathlib import Path
|
3 |
| -from typing import Any |
| 4 | +from typing import Any, Callable |
4 | 5 |
|
5 | 6 | import numpy as np
|
| 7 | +import numpy.typing as npt |
6 | 8 | import pandas as pd
|
7 | 9 | from sklearn.base import BaseEstimator, TransformerMixin
|
8 | 10 |
|
@@ -244,3 +246,41 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
|
244 | 246 | )
|
245 | 247 | .dropna(subset="value")
|
246 | 248 | ).astype({"variable": "category", "value": "category"})
|
| 249 | + |
| 250 | + |
| 251 | +class ObservationMatrix(BaseEstimator, TransformerMixin): |
| 252 | + """Convert exams df to observations""" |
| 253 | + |
| 254 | + def __init__( |
| 255 | + self, risk_agg_method: str | Callable = "max", months_per_bin: float = 3 |
| 256 | + ): |
| 257 | + self.risk_agg_method = risk_agg_method |
| 258 | + self.months_per_bin = months_per_bin |
| 259 | + super().__init__() |
| 260 | + |
| 261 | + def fit(self, X: pd.DataFrame, y=None): |
| 262 | + CleanData(dtypes={"PID": "int64", "age": "timedelta64[ns]", "risk": "Int64"}) |
| 263 | + # Create a mapping between row in matrix and PID |
| 264 | + pids = X["PID"].unique() |
| 265 | + self.pid_to_row = {pid: i for i, pid in enumerate(pids)} |
| 266 | + |
| 267 | + # Make the time bins |
| 268 | + days_per_month = 30 |
| 269 | + bin_width = timedelta(days=self.months_per_bin * days_per_month) |
| 270 | + self.bins: npt.NDArray = np.arange( |
| 271 | + X["age"].min(), |
| 272 | + X["age"].max() + bin_width, # Add to ensure endpoint is included |
| 273 | + bin_width, |
| 274 | + ) |
| 275 | + return self |
| 276 | + |
| 277 | + def transform(self, X: pd.DataFrame) -> pd.DataFrame: |
| 278 | + out = X[["risk"]] |
| 279 | + out["row"] = X["PID"].apply(lambda pid: self.pid_to_row[pid]) |
| 280 | + out["bin"] = pd.cut( |
| 281 | + X["age"], self.bins, right=False |
| 282 | + ) # type: ignore[call-overload] # right=False indicates close left side |
| 283 | + |
| 284 | + return out.groupby(["row", "bin"], as_index=False)["risk"].agg( |
| 285 | + self.risk_agg_method |
| 286 | + ) # type: ignore[return-value] # as_index=False makes this a DataFrame |
0 commit comments