1 change: 0 additions & 1 deletion neuracore/ml/config/config.yaml
@@ -24,7 +24,6 @@ max_prefetch_workers: 4

# Dataset synchronization
dataset_name: null
dataset_id: null
frequency: 10

# You can either specify input_data_types/output_data_types or
17 changes: 3 additions & 14 deletions neuracore/ml/train.py
@@ -225,7 +225,6 @@ def _save_local_training_metadata(
"status": "RUNNING",
"algorithm": algorithm_name,
"algorithm_id": getattr(cfg, "algorithm_id", None),
"dataset_id": getattr(cfg, "dataset_id", None),
"dataset_name": getattr(cfg, "dataset_name", None),
"launch_time": time.time(),
"local_output_dir": str(output_dir),
@@ -587,25 +586,15 @@ def main(cfg: DictConfig) -> None:
"Neither 'algorithm' nor 'algorithm_id' is provided. " "Please specify one."
)

if cfg.dataset_id is None and cfg.dataset_name is None:
raise ValueError("Either 'dataset_id' or 'dataset_name' must be provided.")
if cfg.dataset_id is not None and cfg.dataset_name is not None:
raise ValueError(
"Both 'dataset_id' and 'dataset_name' are provided. "
"Please specify only one."
)
if cfg.dataset_name is None:
raise ValueError("'dataset_name' must be provided.")

# Login and get dataset
nc.login()
if cfg.org_id is not None:
nc.set_organization(cfg.org_id)

if cfg.dataset_id is not None:
dataset = nc.get_dataset(id=cfg.dataset_id)
elif cfg.dataset_name is not None:
dataset = nc.get_dataset(name=cfg.dataset_name)
else:
raise ValueError("Either 'dataset_id' or 'dataset_name' must be provided.")
dataset = nc.get_dataset(name=cfg.dataset_name)
dataset.cache_dir = _resolve_recording_cache_dir(cfg)
dataset.cache_dir.mkdir(parents=True, exist_ok=True)

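Taken together, the train.py changes collapse dataset selection into a single code path. A minimal sketch of how the relevant part of main() reads after this change, assembled only from lines visible in the diff (all other setup in main() is omitted here):

    # dataset_name is now the only supported dataset selector.
    if cfg.dataset_name is None:
        raise ValueError("'dataset_name' must be provided.")

    # Login and get dataset
    nc.login()
    if cfg.org_id is not None:
        nc.set_organization(cfg.org_id)

    # A single lookup by name replaces the old dataset_id/dataset_name branching.
    dataset = nc.get_dataset(name=cfg.dataset_name)
    dataset.cache_dir = _resolve_recording_cache_dir(cfg)
    dataset.cache_dir.mkdir(parents=True, exist_ok=True)
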
86 changes: 18 additions & 68 deletions tests/unit/ml/test_train.py
@@ -1257,18 +1257,9 @@ class TestMain:
(
{
"algorithm_id": "test-algorithm-id",
"dataset_id": None,
"dataset_name": None,
},
"Either 'dataset_id' or 'dataset_name' must be provided",
),
(
{
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": "test-dataset-name",
},
"Both 'dataset_id' and 'dataset_name' are provided",
"'dataset_name' must be provided",
),
],
)
@@ -1277,7 +1268,7 @@ def test_main_raises_validation_errors_for_invalid_configurations(
):
base_cfg = {
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": "test-dataset-name",
"local_output_dir": "/tmp/test",
"batch_size": 8,
"input_robot_data_spec": INPUT_ROBOT_DATA_SPEC,
@@ -1299,7 +1290,6 @@ def test_main_loads_dataset_by_name_when_dataset_name_provided(
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
@@ -1326,8 +1316,7 @@ def test_main_sets_organization_when_org_id_provided(
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": "test-org-id",
"device": None,
"local_output_dir": str(temp_output_dir),
Expand All @@ -1346,7 +1335,7 @@ def test_main_sets_organization_when_org_id_provided(
main(cfg)

setup.mock_set_organization.assert_called_once_with("test-org-id")
setup.mock_get_dataset.assert_called_once_with(id="test-dataset-id")
setup.mock_get_dataset.assert_called_once_with(name="test-dataset-name")

def test_main_uses_algorithm_config_when_algorithm_provided_instead_of_algorithm_id(
self, monkeypatch, temp_output_dir
@@ -1356,8 +1345,7 @@ def test_main_uses_algorithm_config_when_algorithm_provided_instead_of_algorithm
"_target_": "tests.unit.ml.test_train.mock_model_class",
},
"algorithm_id": None,
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),
@@ -1381,8 +1369,7 @@ def test_main_uses_default_device_when_device_is_none(
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),
@@ -1408,8 +1395,7 @@ def test_main_uses_explicit_device_when_device_is_provided(
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": "cuda:1",
"local_output_dir": str(temp_output_dir),
@@ -1437,8 +1423,7 @@ def test_main_uses_provided_batch_size_when_not_auto(
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),
@@ -1464,8 +1449,7 @@ def test_main_loads_algorithm_by_id_when_algorithm_not_in_cfg_but_algorithm_id_p
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),
@@ -1492,39 +1476,12 @@ def test_main_loads_algorithm_by_id_when_algorithm_not_in_cfg_but_algorithm_id_p
extract_dir=expected_extract_dir
)

def test_main_loads_dataset_by_id_when_dataset_id_provided_but_dataset_name_none(
self, monkeypatch, temp_output_dir
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),
"batch_size": 8,
"input_robot_data_spec": INPUT_ROBOT_DATA_SPEC,
"output_robot_data_spec": OUTPUT_ROBOT_DATA_SPEC,
"output_prediction_horizon": 5,
"frequency": 30,
"algorithm_params": None,
"max_prefetch_workers": 4,
})

setup = MainTestSetup(monkeypatch)
setup.setup_mocks()

main(cfg)

setup.mock_get_dataset.assert_called_once_with(id="test-dataset-id")

def test_main_converts_string_batch_size_to_int_when_not_auto(
self, monkeypatch, temp_output_dir
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),
@@ -1557,8 +1514,7 @@ def test_main_uses_mp_spawn_for_distributed_training_when_world_size_greater_tha
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),
@@ -1597,8 +1553,7 @@ def test_main_uses_mp_spawn_for_distributed_training_when_world_size_greater_tha
def test_main_calls_setup_logging(self, monkeypatch, temp_output_dir):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),
@@ -1626,8 +1581,7 @@ def test_main_saves_local_metadata_for_local_runs(
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": "test-org",
"device": None,
"local_output_dir": str(temp_output_dir),
@@ -1649,7 +1603,7 @@ def test_main_saves_local_metadata_for_local_runs(
assert metadata_path.exists()
metadata = json.loads(metadata_path.read_text())
assert metadata["algorithm"] == "test-algorithm"
assert metadata["dataset_id"] == "test-dataset-id"
assert metadata["dataset_name"] == "test-dataset-name"
assert metadata["status"] == "RUNNING"
assert "JOINT_POSITIONS" in metadata["input_robot_data_spec"]["robot-id-1"]

@@ -1658,8 +1612,7 @@ def test_main_calls_dataset_synchronize_with_correct_parameters(
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),
@@ -1694,8 +1647,7 @@ def test_main_calls_dataset_synchronize_with_correct_parameters(
def test_main_uses_default_recording_cache_dir(self, monkeypatch, temp_output_dir):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),
@@ -1720,8 +1672,7 @@ def test_main_uses_custom_recording_cache_dir(self, monkeypatch, temp_output_dir
custom_cache_dir = temp_output_dir / "custom-recording-cache"
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),
@@ -1747,8 +1698,7 @@ def test_main_uses_autotuning_when_batch_size_is_auto(
):
cfg = OmegaConf.create({
"algorithm_id": "test-algorithm-id",
"dataset_id": "test-dataset-id",
"dataset_name": None,
"dataset_name": "test-dataset-name",
"org_id": None,
"device": None,
"local_output_dir": str(temp_output_dir),