Skip to content

Commit 91da649

Browse files
committed
Make some changes to cohorts to keep compatibility between Analysis and Experience
Rebase, update citation and contributors
1 parent 75c0291 commit 91da649

File tree

4 files changed

+25
-1
lines changed

4 files changed

+25
-1
lines changed

CONTRIBUTORS.md

+2 -1
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ The _SCALPEL-Analysis_ package was initially implemented by researchers, develop
66

77
- Youcef Sebiat
88
- Maryan Morel
9-
- Dinh Phong Nguyen
9+
- Dinh Phong Nguyen
10+
- Dian Sun

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ If you use a library part of _SCALPEL3_ in a scientific publication, we would ap
237237
year={2020},
238238
publisher={Elsevier}
239239
}
240+
240241

241242
## Contributing
242243
The development cycle is opinionated. Each time you commit, git will

scalpel/core/cohort.py

+6
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,12 @@ def load(input: Dict) -> "Cohort":
284284
def from_description(description: str) -> "Cohort":
285285
raise NotImplementedError
286286

287+
def cache(self) -> "Cohort":
288+
self.subjects = self.subjects.cache()
289+
if self.events is not None:
290+
self.events = self.events.cache()
291+
return self
292+
287293

288294
def _union(a: Cohort, b: Cohort) -> Cohort:
289295
if a.events is None or b.events is None:

tests/core/cohort_test.py

+16
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,19 @@ def test_save_cohort(self, mock_method):
297297
},
298298
cohort_2.save_cohort("../../output"),
299299
)
300+
301+
def test_cache(self):
302+
patients_pd = pd.DataFrame({"patientID": [1, 2, 3]})
303+
events_pd = pd.DataFrame({"patientID": [1, 2, 3], "value": ["DP", "DAS", "DR"]})
304+
305+
patients = self.spark.createDataFrame(patients_pd)
306+
307+
events = self.spark.createDataFrame(events_pd)
308+
cohort = Cohort("liberal_fractures", "liberal_fractures", patients, events)
309+
cohort.cache()
310+
assert cohort.subjects.storageLevel.useMemory
311+
assert cohort.events.storageLevel.useMemory
312+
313+
def test_from_description(self):
314+
self.assertRaises(NotImplementedError, Cohort.from_description,
315+
description='some string')

0 commit comments

Comments
 (0)