From 21b8e01aa551b6362bf2fa24c92b9f80e6b2b366 Mon Sep 17 00:00:00 2001 From: Bhimraj Yadav Date: Fri, 22 Aug 2025 08:19:01 +0000 Subject: [PATCH 1/2] feat(resolver): add support for resolving directories in lightning storage --- src/litdata/streaming/resolver.py | 25 ++++++++++++++++++ tests/streaming/test_resolver.py | 44 +++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/src/litdata/streaming/resolver.py b/src/litdata/streaming/resolver.py index 51db59528..4b44349ab 100644 --- a/src/litdata/streaming/resolver.py +++ b/src/litdata/streaming/resolver.py @@ -89,6 +89,9 @@ def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir: if dir_path_absolute.startswith("/teamspace/gcs_folders") and len(dir_path_absolute.split("/")) > 3: return _resolve_gcs_folders(dir_path_absolute) + if dir_path_absolute.startswith("/teamspace/lightning_storage") and len(dir_path_absolute.split("/")) > 3: + return _resolve_lightning_storage(dir_path_absolute) + if dir_path_absolute.startswith("/teamspace/datasets") and len(dir_path_absolute.split("/")) > 3: return _resolve_datasets(dir_path_absolute) @@ -246,6 +249,28 @@ def _resolve_gcs_folders(dir_path: str) -> Dir: return Dir(path=dir_path, url=os.path.join(data_connection[0].gcs_folder.source, *dir_path.split("/")[4:])) +def _resolve_lightning_storage(dir_path: str) -> Dir: + from lightning_sdk.lightning_cloud.rest_client import LightningClient + + client = LightningClient(max_tries=2) + + # Get the ids from env variables + project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) + if project_id is None: + raise RuntimeError("The `LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables.") + + target_name = dir_path.split("/")[3] + + data_connections = client.data_connection_service_list_data_connections(project_id).data_connections + + data_connection = [dc for dc in data_connections if dc.name == target_name] + + if not data_connection: + raise ValueError(f"We didn't find any matching data connection with the provided name `{target_name}`.") + + return Dir(path=dir_path, url=os.path.join(data_connection[0].r2.source, *dir_path.split("/")[4:])) + + def _resolve_datasets(dir_path: str) -> Dir: from lightning_sdk.lightning_cloud.rest_client import LightningClient diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 33162a6ae..54147e3c8 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -465,3 +465,47 @@ def test_src_resolver_gcs_folders(monkeypatch, lightning_cloud_mock): assert resolver._resolve_dir("/teamspace/gcs_folders/debug_folder/a/b/c").url == expected + "/a/b/c" auth.clear() + + +@pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported") +def test_src_resolver_lightning_storage(monkeypatch, lightning_cloud_mock): + """Test lightning_storage resolver with r2 source.""" + auth = login.Auth() + auth.save(user_id="7c8455e3-7c5f-4697-8a6d-105971d6b9bd", api_key="e63fae57-2b50-498b-bc46-d6204cbf330e") + + with pytest.raises( + RuntimeError, match="`LIGHTNING_CLOUD_PROJECT_ID` couldn't be found from the environment variables." + ): + resolver._resolve_dir("/teamspace/lightning_storage/my_dataset") + + monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id") + + client_mock = mock.MagicMock() + client_mock.data_connection_service_list_data_connections.return_value = V1ListDataConnectionsResponse( + data_connections=[ + V1DataConnection(name="my_dataset", r2=mock.MagicMock(source="r2://my-r2-bucket")) + ], + ) + + client_cls_mock = mock.MagicMock() + client_cls_mock.return_value = client_mock + lightning_cloud_mock.rest_client.LightningClient = client_cls_mock + + expected = "r2://my-r2-bucket" + assert resolver._resolve_dir("/teamspace/lightning_storage/my_dataset").url == expected + assert resolver._resolve_dir("/teamspace/lightning_storage/my_dataset/train").url == expected + "/train" + + # Test missing data connection + client_mock = mock.MagicMock() + client_mock.data_connection_service_list_data_connections.return_value = V1ListDataConnectionsResponse( + data_connections=[], + ) + + client_cls_mock = mock.MagicMock() + client_cls_mock.return_value = client_mock + lightning_cloud_mock.rest_client.LightningClient = client_cls_mock + + with pytest.raises(ValueError, match="name `my_dataset`"): + resolver._resolve_dir("/teamspace/lightning_storage/my_dataset") + + auth.clear() \ No newline at end of file From da2c81c96b96cc97f57f256886042e2af864e9e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 22 Aug 2025 08:21:50 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/streaming/test_resolver.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/streaming/test_resolver.py b/tests/streaming/test_resolver.py index 54147e3c8..517f4205b 100644 --- a/tests/streaming/test_resolver.py +++ b/tests/streaming/test_resolver.py @@ -482,9 +482,7 @@ def test_src_resolver_lightning_storage(monkeypatch, lightning_cloud_mock): client_mock = mock.MagicMock() client_mock.data_connection_service_list_data_connections.return_value = V1ListDataConnectionsResponse( - data_connections=[ - V1DataConnection(name="my_dataset", r2=mock.MagicMock(source="r2://my-r2-bucket")) - ], + data_connections=[V1DataConnection(name="my_dataset", r2=mock.MagicMock(source="r2://my-r2-bucket"))], ) client_cls_mock = mock.MagicMock() @@ -508,4 +506,4 @@ def test_src_resolver_lightning_storage(monkeypatch, lightning_cloud_mock): with pytest.raises(ValueError, match="name `my_dataset`"): resolver._resolve_dir("/teamspace/lightning_storage/my_dataset") - auth.clear() \ No newline at end of file + auth.clear()