Commit ba14569

Fix memory consumption issue in observation transformer
1 parent 5cdfe18 · commit ba14569

1 file changed: +5 −2 lines changed

decipher/processing/transformers.py

@@ -280,7 +280,7 @@ def fit(self, X: pd.DataFrame, y=None):
         CleanData(dtypes={"PID": "int64", "age": "timedelta64[ns]", "risk": "Int64"})
         # Create a mapping between row in matrix and PID
         pids = X["PID"].unique()
-        self.pid_to_row = {pid: i for i, pid in enumerate(pids)}
+        self.pid_to_row: dict[int, int] = {pid: i for i, pid in enumerate(pids)}

         # Make the time bins
         days_per_month = 30
@@ -302,6 +302,9 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
                 ),  # type: ignore[call-overload] # right=False indicates close left side
             }
         )
-        return out.groupby(["row", "bin"], as_index=False)["risk"].agg(
+        # The observed=True is important!
+        # As the bin is categorical, observed=False will produce a cartesian product
+        # between all possible bins and all rows. This eats up a lot of memory!
+        return out.groupby(["row", "bin"], as_index=False, observed=True)["risk"].agg(
             self.risk_agg_method
         )  # type: ignore[return-value] # as_index=False makes this a DataFrame
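For context, the pandas behaviour the new comments describe can be shown with a small, self-contained sketch (not part of this commit). The "row", "bin" and "risk" column names match the diff above; the data, bin edges and sizes are made up for illustration.

import numpy as np
import pandas as pd

# Synthetic stand-in data; the real transformer derives "row", "bin" and "risk"
# from the observation table, but the values here are invented.
rng = np.random.default_rng(0)
n_people = 1_000
obs = pd.DataFrame(
    {
        "row": rng.integers(0, n_people, size=5 * n_people),
        "age_days": rng.integers(0, 36_500, size=5 * n_people),
        "risk": rng.integers(1, 5, size=5 * n_people),
    }
)

# pd.cut returns a categorical column whose categories are *all* possible bins,
# not just the bins that occur in the data (here roughly 1200 monthly bins).
obs["bin"] = pd.cut(obs["age_days"], bins=range(0, 36_540, 30), right=False)

# observed=False materialises every (row, bin) combination, including bins a
# row never appears in: roughly n_people * n_bins groups.
dense = obs.groupby(["row", "bin"], as_index=False, observed=False)["risk"].agg("max")

# observed=True keeps only the (row, bin) pairs actually present in the data.
sparse = obs.groupby(["row", "bin"], as_index=False, observed=True)["risk"].agg("max")

print(len(dense))   # on the order of a million rows for this toy example
print(len(sparse))  # at most 5 * n_people rows, far smaller

The cartesian product in the observed=False case is what drove the memory consumption this commit fixes; passing observed=True keeps the grouped result proportional to the number of observations.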
