-
Notifications
You must be signed in to change notification settings - Fork 0
Bsweger/get task id values #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bdadf63
813b73c
8da9a7b
e9adf0b
d36ab94
cabfe2e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,9 @@ | ||
# mypy: disable-error-code="operator" | ||
# mypy: disable-error-code="operator, arg-type" | ||
|
||
import itertools | ||
import json | ||
import os | ||
from datetime import date | ||
|
||
from cloudpathlib import AnyPath | ||
|
||
|
@@ -46,6 +48,30 @@ def __repr__(self): | |
def __str__(self): | ||
return f"Hubverse config information for {self.hub_name}." | ||
|
||
def get_task_id_values(self) -> dict[str, set]: | ||
""" | ||
Return a dict of hub task ids and required and optional values across all rounds. | ||
|
||
Returns | ||
------- | ||
model_tasks : dict[str, list] | ||
A mapping of tasks ids to their possible values, as | ||
defined in a hub's tasks.json configuration file. | ||
""" | ||
|
||
tasks = self.tasks | ||
rounds = tasks.get("rounds", []) | ||
|
||
model_tasks_dict: dict[str, set] = dict() | ||
|
||
for r in rounds: | ||
for task_set in r.get("model_tasks", []): | ||
task_values = self._get_task_id_values(task_set) | ||
for key, value in task_values.items(): | ||
model_tasks_dict[key] = model_tasks_dict.get(key, set()) | value | ||
|
||
return model_tasks_dict | ||
|
||
def _get_admin_config(self, admin_file: str): | ||
"""Read a Hubverse hub's admin configuration file.""" | ||
admin_path = self.hub_config_path / admin_file | ||
|
@@ -65,3 +91,31 @@ def _get_tasks_config(self, tasks_file: str): | |
|
||
with tasks_path.open() as f: | ||
return json.loads(f.read()) | ||
|
||
def _get_task_id_values(self, task_set: dict) -> dict[str, set]: | ||
"""Return a dict of ids and values for a specific modeling task.""" | ||
task_id_values = dict() | ||
|
||
# create a dictionary of all tasks ids and values for this task | ||
task_ids = {task_id[0]: task_id[1] for task_id in task_set.get("task_ids", {}).items()} | ||
|
||
# flatten the dictionary values for each task_id (i.e., combine "required" and "optional" lists) | ||
for task_id in task_ids.items(): | ||
task_id_values[task_id[0]] = set( | ||
itertools.chain.from_iterable([value or [] for value in task_id[1].values()]) | ||
) | ||
|
||
return task_id_values | ||
|
||
def _get_data_type(self, value: int | bool | str | date | float) -> type: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got ahead of myself here: this is the helper that will eventually determine the data type of our task_ids and output_type_ids, but it's not being used yet. |
||
"""Return the data type of a value.""" | ||
data_type = type(value) | ||
|
||
if data_type == str: | ||
try: | ||
date.fromisoformat(value) | ||
data_type = date | ||
except ValueError: | ||
pass | ||
|
||
return data_type |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import json | ||
from datetime import date | ||
|
||
import pytest | ||
from cloudpathlib import AnyPath | ||
|
@@ -57,11 +58,11 @@ def tasks_config() -> dict: | |
}, | ||
{ | ||
"task_ids": { | ||
"reference_date": {"required": None, "optional": ["2024-07-13", "2024-07-21"]}, | ||
"reference_date": {"required": None, "optional": ["2024-08-13", "2024-07-21"]}, | ||
"target": {"required": ["wk number assimilations"]}, | ||
"horizon": {"required": ["one", "two", "three"]}, | ||
"location": {"required": ["Earth", "Vulcan", "789"], "optional": ["Ryza", "123"]}, | ||
"target_end_date": {"required": None, "optional": ["2024-07-20", "2024-07-27"]}, | ||
"location": {"required": ["Earth", "Bajor", "123"], "optional": []}, | ||
"target_end_date": {"required": ["1999-12-31"], "optional": ["2024-08-20", "2024-07-27"]}, | ||
}, | ||
"output_type": { | ||
"quantile": { | ||
|
@@ -72,7 +73,25 @@ def tasks_config() -> dict: | |
}, | ||
], | ||
"submissions_due": {"relative_to": "reference_date", "start": -6, "end": -3}, | ||
} | ||
}, | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added another round to our test data so we can test the new function for multiple rounds. |
||
"round_name": "Round 2", | ||
elray1 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"model_tasks": [ | ||
{ | ||
"task_ids": { | ||
"age": {"required": [10, 20, 30], "optional": [40, 50]}, | ||
"target": {"required": ["starfleet entrance exam score"]}, | ||
}, | ||
"output_type": { | ||
"mean": { | ||
"output_type_id": {"required": ["NA"], "optional": None}, | ||
"value": {"type": "double", "minimum": 0, "maximum": 1}, | ||
} | ||
}, | ||
} | ||
], | ||
"submissions_due": {"relative_to": "reference_date", "start": -100, "end": 5}, | ||
}, | ||
], | ||
} | ||
|
||
|
@@ -128,3 +147,34 @@ def test_hub_missing_tasks_config(hubverse_hub): | |
with pytest.raises(FileNotFoundError): | ||
hc = HubConfig(hub_path, tasks_file="missing-tasks.json") | ||
print(hc) | ||
|
||
|
||
def test_get_task_id_values(hubverse_hub): | ||
hub_path = AnyPath(hubverse_hub) | ||
hc = HubConfig(hub_path) | ||
|
||
assert hc.get_task_id_values() == { | ||
"reference_date": {"2024-07-13", "2024-07-21", "2024-08-13"}, | ||
"target": {"borg growth rate change", "wk number assimilations", "starfleet entrance exam score"}, | ||
"horizon": {-1, 0, 1, 2, 3, "one", "two", "three"}, | ||
"location": {"Earth", "Vulcan", "789", "Ryza", "123", "Bajor"}, | ||
"target_end_date": {"2024-07-20", "2024-07-27", "1999-12-31", "2024-08-20"}, | ||
"age": {10, 20, 30, 40, 50}, | ||
} | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"value, expected_type", | ||
[ | ||
("a string", str), | ||
(123, int), | ||
(123.45, float), | ||
("2024-07-13", date), | ||
(False, bool), | ||
], | ||
) | ||
def test_get_data_type(hubverse_hub, value, expected_type): | ||
hub_path = AnyPath(hubverse_hub) | ||
hc = HubConfig(hub_path) | ||
|
||
assert hc._get_data_type(value) == expected_type |
Uh oh!
There was an error while loading. Please reload this page.