Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bsweger/get task id values #22

Merged
merged 6 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion src/hubverse_transform/hub_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# mypy: disable-error-code="operator"
# mypy: disable-error-code="operator, arg-type"

import itertools
import json
import os
from datetime import date

from cloudpathlib import AnyPath

Expand Down Expand Up @@ -46,6 +48,30 @@ def __repr__(self):
def __str__(self):
return f"Hubverse config information for {self.hub_name}."

def get_task_id_values(self) -> dict[str, set]:
    """
    Collect the values each task id can take, merged across all rounds.

    Every model task of every round listed under the "rounds" key of
    ``self.tasks`` is visited; the per-task values reported by
    ``_get_task_id_values`` are unioned per task id.

    Returns
    -------
    dict[str, set]
        A mapping of each task id to the set of values seen for it in
        any round. Empty when the config has no rounds.
    """
    merged_values: dict[str, set] = {}

    for hub_round in self.tasks.get("rounds", []):
        for model_task in hub_round.get("model_tasks", []):
            values_by_id = self._get_task_id_values(model_task)
            for task_id, values in values_by_id.items():
                # Union into a fresh set so the helper's sets are never mutated.
                merged_values[task_id] = merged_values.get(task_id, set()) | values

    return merged_values

def _get_admin_config(self, admin_file: str):
"""Read a Hubverse hub's admin configuration file."""
admin_path = self.hub_config_path / admin_file
Expand All @@ -65,3 +91,31 @@ def _get_tasks_config(self, tasks_file: str):

with tasks_path.open() as f:
return json.loads(f.read())

def _get_task_id_values(self, task_set: dict) -> dict[str, set]:
"""Return a dict of ids and values for a specific modeling task."""
task_id_values = dict()

# create a dictionary of all tasks ids and values for this task
task_ids = {task_id[0]: task_id[1] for task_id in task_set.get("task_ids", {}).items()}

# flatten the dictionary values for each task_id (i.e., combine "required" and "optional" lists)
for task_id in task_ids.items():
task_id_values[task_id[0]] = set(
itertools.chain.from_iterable([value or [] for value in task_id[1].values()])
)

return task_id_values

def _get_data_type(self, value: int | bool | str | date | float) -> type:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got ahead of myself here: this is the helper that will eventually determine the data type of our task_ids and output_type_ids, but it's not being used yet.

"""Return the data type of a value."""
data_type = type(value)

if data_type == str:
try:
date.fromisoformat(value)
data_type = date
except ValueError:
pass

return data_type
58 changes: 54 additions & 4 deletions test/unit/test_hub_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from datetime import date

import pytest
from cloudpathlib import AnyPath
Expand Down Expand Up @@ -57,11 +58,11 @@ def tasks_config() -> dict:
},
{
"task_ids": {
"reference_date": {"required": None, "optional": ["2024-07-13", "2024-07-21"]},
"reference_date": {"required": None, "optional": ["2024-08-13", "2024-07-21"]},
"target": {"required": ["wk number assimilations"]},
"horizon": {"required": ["one", "two", "three"]},
"location": {"required": ["Earth", "Vulcan", "789"], "optional": ["Ryza", "123"]},
"target_end_date": {"required": None, "optional": ["2024-07-20", "2024-07-27"]},
"location": {"required": ["Earth", "Bajor", "123"], "optional": []},
"target_end_date": {"required": ["1999-12-31"], "optional": ["2024-08-20", "2024-07-27"]},
},
"output_type": {
"quantile": {
Expand All @@ -72,7 +73,25 @@ def tasks_config() -> dict:
},
],
"submissions_due": {"relative_to": "reference_date", "start": -6, "end": -3},
}
},
{
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added another round to our test data so we can test the new function for multiple rounds.

"round_name": "Round 2",
"model_tasks": [
{
"task_ids": {
"age": {"required": [10, 20, 30], "optional": [40, 50]},
"target": {"required": ["starfleet entrance exam score"]},
},
"output_type": {
"mean": {
"output_type_id": {"required": ["NA"], "optional": None},
"value": {"type": "double", "minimum": 0, "maximum": 1},
}
},
}
],
"submissions_due": {"relative_to": "reference_date", "start": -100, "end": 5},
},
],
}

Expand Down Expand Up @@ -128,3 +147,34 @@ def test_hub_missing_tasks_config(hubverse_hub):
with pytest.raises(FileNotFoundError):
hc = HubConfig(hub_path, tasks_file="missing-tasks.json")
print(hc)


def test_get_task_id_values(hubverse_hub):
    """Check the task id values HubConfig reports for the test hub fixture."""
    expected_values = {
        "reference_date": {"2024-07-13", "2024-07-21", "2024-08-13"},
        "target": {"borg growth rate change", "wk number assimilations", "starfleet entrance exam score"},
        "horizon": {-1, 0, 1, 2, 3, "one", "two", "three"},
        "location": {"Earth", "Vulcan", "789", "Ryza", "123", "Bajor"},
        "target_end_date": {"2024-07-20", "2024-07-27", "1999-12-31", "2024-08-20"},
        "age": {10, 20, 30, 40, 50},
    }

    hub_config = HubConfig(AnyPath(hubverse_hub))

    assert hub_config.get_task_id_values() == expected_values


@pytest.mark.parametrize(
    ("value", "expected_type"),
    [
        ("a string", str),
        (123, int),
        (123.45, float),
        ("2024-07-13", date),
        (False, bool),
    ],
)
def test_get_data_type(hubverse_hub, value, expected_type):
    """Check the type HubConfig._get_data_type reports for each sample value."""
    hub_config = HubConfig(AnyPath(hubverse_hub))

    detected_type = hub_config._get_data_type(value)

    assert detected_type == expected_type