python/pytorch: Refactor datasets into separate files
Signed-off-by: Soham Manoli <[email protected]>
msoham123 committed Jun 13, 2024
1 parent 3e21e3e commit 68e98f4
Showing 6 changed files with 171 additions and 145 deletions.
5 changes: 4 additions & 1 deletion python/aistore/pytorch/__init__.py
@@ -3,4 +3,7 @@
AISFileLoaderIterDataPipe as AISFileLoader,
AISSourceLister,
)
from aistore.pytorch.dataset import AISDataset, AISIterDataset, AISMultiShardStream

from aistore.pytorch.dataset import AISDataset
from aistore.pytorch.multishard_dataset import AISMultiShardStream
from aistore.pytorch.iter_dataset import AISIterDataset
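After this change the top-level package still re-exports all three classes, so user code can keep importing from aistore.pytorch; only imports that reached into aistore.pytorch.dataset for the iterable and multishard classes need to move. A minimal sketch of both import styles, grounded in the diff above:

# Re-exported at the package level, unchanged for callers:
from aistore.pytorch import AISDataset, AISIterDataset, AISMultiShardStream

# Or import from the new per-class modules directly:
from aistore.pytorch.dataset import AISDataset
from aistore.pytorch.iter_dataset import AISIterDataset
from aistore.pytorch.multishard_dataset import AISMultiShardStream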
66 changes: 66 additions & 0 deletions python/aistore/pytorch/base_dataset.py
@@ -0,0 +1,66 @@
"""
Base classes for AIS Datasets and Iterable Datasets
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

from typing import List, Union
from aistore.sdk.ais_source import AISSource
from aistore.pytorch.utils import list_objects, list_objects_iterator
from aistore.sdk import Client


class AISBaseClass:
"""
A base class for creating AIS Datasets for PyTorch.
Args:
client_url (str): AIS endpoint URL
urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data
ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
"""

def __init__(
self,
client_url: str,
urls_list: Union[str, List[str]],
ais_source_list: Union[AISSource, List[AISSource]],
) -> None:
self.client = Client(client_url)
if isinstance(urls_list, str):
urls_list = [urls_list]
if isinstance(ais_source_list, AISSource):
ais_source_list = [ais_source_list]
self._objects = list_objects(self.client, urls_list, ais_source_list)


class AISBaseClassIter:
"""
A base class for creating AIS Iterable Datasets for PyTorch.
Args:
client_url (str): AIS endpoint URL
urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data
ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
"""

def __init__(
self,
client_url: str,
urls_list: Union[str, List[str]],
ais_source_list: Union[AISSource, List[AISSource]],
) -> None:
self.client = Client(client_url)
if isinstance(urls_list, str):
urls_list = [urls_list]
if isinstance(ais_source_list, AISSource):
ais_source_list = [ais_source_list]
self.urls_list = urls_list
self.ais_source_list = ais_source_list
self._reset_iterator()

def _reset_iterator(self):
"""Reset the object iterator to start from the beginning"""
self._object_iter = list_objects_iterator(
self.client, self.urls_list, self.ais_source_list
)
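Both base classes normalize a single URL prefix or AISSource into a list and construct the Client; subclasses only add the indexing or iteration logic. A hedged sketch of a custom map-style dataset built on AISBaseClass — the attribute and method names (self._objects, obj.name, obj.get(etl_name=...).read_all()) follow the code in this commit, while the 1 KiB truncation is purely illustrative:

from torch.utils.data import Dataset
from aistore.pytorch.base_dataset import AISBaseClass


class FirstKiBDataset(AISBaseClass, Dataset):
    """Illustrative subclass: yields only the first KiB of each object."""

    def __len__(self):
        return len(self._objects)

    def __getitem__(self, index: int):
        obj = self._objects[index]
        # Same access pattern as AISDataset.__getitem__, truncated for illustration
        return obj.name, obj.get(etl_name=None).read_all()[:1024]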
143 changes: 4 additions & 139 deletions python/aistore/pytorch/dataset.py
@@ -1,46 +1,13 @@
"""
AIS Plugin for PyTorch
PyTorch Dataset and DataLoader for AIS.
PyTorch Dataset for AIS.
Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
"""

from typing import Iterator, List, Union
from torch.utils.data import Dataset, IterableDataset

from aistore.sdk import Client
from typing import List, Union
from torch.utils.data import Dataset
from aistore.sdk.ais_source import AISSource
from aistore.sdk.dataset.data_shard import DataShard
from aistore.pytorch.utils import (
list_objects,
list_objects_iterator,
list_shard_objects_iterator,
)


class AISBaseClass:
"""
A base class for creating AIS Datasets for PyTorch.
Args:
client_url (str): AIS endpoint URL
urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data
ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
"""

def __init__(
self,
client_url: str,
urls_list: Union[str, List[str]],
ais_source_list: Union[AISSource, List[AISSource]],
) -> None:
self.client = Client(client_url)
if isinstance(urls_list, str):
urls_list = [urls_list]
if isinstance(ais_source_list, AISSource):
ais_source_list = [ais_source_list]
self._objects = list_objects(self.client, urls_list, ais_source_list)
from aistore.pytorch.base_dataset import AISBaseClass


class AISDataset(AISBaseClass, Dataset):
@@ -79,105 +46,3 @@ def __getitem__(self, index: int):
obj = self._objects[index]
content = obj.get(etl_name=self.etl_name).read_all()
return obj.name, content


class AISBaseClassIter:
"""
A base class for creating AIS Iterable Datasets for PyTorch.
Args:
client_url (str): AIS endpoint URL
urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data
ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
"""

def __init__(
self,
client_url: str,
urls_list: Union[str, List[str]],
ais_source_list: Union[AISSource, List[AISSource]],
) -> None:
self.client = Client(client_url)
if isinstance(urls_list, str):
urls_list = [urls_list]
if isinstance(ais_source_list, AISSource):
ais_source_list = [ais_source_list]
self.urls_list = urls_list
self.ais_source_list = ais_source_list
self._reset_iterator()

def _reset_iterator(self):
"""Reset the object iterator to start from the beginning"""
self._object_iter = list_objects_iterator(
self.client, self.urls_list, self.ais_source_list
)


class AISIterDataset(AISBaseClassIter, IterableDataset):
"""
An iterable-style dataset that iterates over objects in AIS.
If `etl_name` is provided, that ETL must already exist on the AIStore cluster.
Args:
client_url (str): AIS endpoint URL
urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data
ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
etl_name (str, optional): Optional ETL on the AIS cluster to apply to each object
Note:
Each object is represented as a tuple of object_name (str) and object_content (bytes)
"""

def __init__(
self,
client_url: str,
urls_list: Union[str, List[str]] = [],
ais_source_list: Union[AISSource, List[AISSource]] = [],
etl_name: str = None,
):
if not urls_list and not ais_source_list:
raise ValueError(
"At least one of urls_list or ais_source_list must be provided."
)
super().__init__(client_url, urls_list, ais_source_list)
self.etl_name = etl_name
self.length = None

def __iter__(self):
self._reset_iterator()
for obj in self._object_iter:
obj_name = obj.name
content = obj.get(etl_name=self.etl_name).read_all()
yield obj_name, content

def __len__(self):
if self.length is None:
self._reset_iterator()
self.length = self._calculate_len()
return self.length

def _calculate_len(self):
return sum(1 for _ in self._object_iter)


class AISMultiShardStream(IterableDataset):
"""
An iterable-style dataset that iterates over multiple shard streams and yields combined samples.
Args:
data_sources (List[DataShard]): List of DataShard objects
Returns:
Iterable: Iterable over the combined samples, where each sample is a tuple of
one object bytes from each shard stream
"""

def __init__(self, data_sources: List[DataShard]):
self.data_sources = data_sources

def __iter__(self) -> Iterator:
data_iterators = (
list_shard_objects_iterator(ds.bucket, ds.prefix, ds.etl_name)
for ds in self.data_sources
)
return zip(*data_iterators)
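For reference, a minimal usage sketch of the map-style AISDataset after the move. The constructor is assumed to mirror the AISIterDataset signature below (client_url, urls_list, ais_source_list, etl_name), which the collapsed part of this diff does not show; the endpoint and bucket are placeholders:

from torch.utils.data import DataLoader
from aistore.pytorch.dataset import AISDataset

dataset = AISDataset(
    client_url="http://localhost:8080",  # placeholder AIS endpoint
    urls_list="ais://my-bucket",         # placeholder bucket/prefix
)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for names, contents in loader:
    # names: list of object names, contents: list of raw object bytes per batch
    ...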
57 changes: 57 additions & 0 deletions python/aistore/pytorch/iter_dataset.py
@@ -0,0 +1,57 @@
"""
Iterable Dataset for AIS
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

from aistore.pytorch.base_dataset import AISBaseClassIter
from torch.utils.data import IterableDataset
from typing import List, Union
from aistore.sdk.ais_source import AISSource


class AISIterDataset(AISBaseClassIter, IterableDataset):
"""
An iterable-style dataset that iterates over objects in AIS.
If `etl_name` is provided, that ETL must already exist on the AIStore cluster.
Args:
client_url (str): AIS endpoint URL
urls_list (Union[str, List[str]]): Single or list of URL prefixes to load data
ais_source_list (Union[AISSource, List[AISSource]]): Single or list of AISSource objects to load data
etl_name (str, optional): Optional ETL on the AIS cluster to apply to each object
Note:
Each object is represented as a tuple of object_name (str) and object_content (bytes)
"""

def __init__(
self,
client_url: str,
urls_list: Union[str, List[str]] = [],
ais_source_list: Union[AISSource, List[AISSource]] = [],
etl_name: str = None,
):
if not urls_list and not ais_source_list:
raise ValueError(
"At least one of urls_list or ais_source_list must be provided."
)
super().__init__(client_url, urls_list, ais_source_list)
self.etl_name = etl_name
self.length = None

def __iter__(self):
self._reset_iterator()
for obj in self._object_iter:
obj_name = obj.name
content = obj.get(etl_name=self.etl_name).read_all()
yield obj_name, content

def __len__(self):
if self.length is None:
self._reset_iterator()
self.length = self._calculate_len()
return self.length

def _calculate_len(self):
return sum(1 for _ in self._object_iter)
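A minimal usage sketch for the iterable dataset; batch_size=None yields one (name, content) pair at a time, and as written every DataLoader worker would iterate the full object stream. Endpoint and bucket are placeholders:

from torch.utils.data import DataLoader
from aistore.pytorch.iter_dataset import AISIterDataset

dataset = AISIterDataset(
    client_url="http://localhost:8080",  # placeholder AIS endpoint
    urls_list="ais://my-bucket",         # placeholder bucket/prefix
)

for name, content in DataLoader(dataset, batch_size=None):
    ...  # name: object name (str), content: object bytes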
33 changes: 33 additions & 0 deletions python/aistore/pytorch/multishard_dataset.py
@@ -0,0 +1,33 @@
"""
Multishard Stream Dataset for AIS.
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""

from torch.utils.data import IterableDataset
from aistore.sdk.dataset.data_shard import DataShard
from typing import Iterator, List
from aistore.pytorch.utils import list_shard_objects_iterator


class AISMultiShardStream(IterableDataset):
"""
An iterable-style dataset that iterates over multiple shard streams and yields combined samples.
Args:
data_sources (List[DataShard]): List of DataShard objects
Returns:
Iterable: Iterable over the combined samples, where each sample is a tuple of
one object bytes from each shard stream
"""

def __init__(self, data_sources: List[DataShard]):
self.data_sources = data_sources

def __iter__(self) -> Iterator:
data_iterators = (
list_shard_objects_iterator(ds.bucket, ds.prefix, ds.etl_name)
for ds in self.data_sources
)
return zip(*data_iterators)
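A hedged sketch of consuming the multi-shard stream. The DataShard objects are assumed to be constructed elsewhere via the SDK; as the code above shows, the stream relies only on their bucket, prefix, and etl_name attributes, and each iteration step yields one object's bytes from every shard, zipped together:

from aistore.pytorch.multishard_dataset import AISMultiShardStream

# shard_images and shard_labels are pre-constructed DataShard objects
# describing two aligned shard streams (e.g. images and labels).
stream = AISMultiShardStream(data_sources=[shard_images, shard_labels])

for image_bytes, label_bytes in stream:
    # One sample = one object from each shard stream, in lockstep.
    ...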
12 changes: 7 additions & 5 deletions python/tests/unit/pytorch/test_datasets.py
@@ -1,6 +1,8 @@
import unittest
from unittest.mock import patch, Mock, MagicMock
from aistore.pytorch.dataset import AISDataset, AISIterDataset, AISMultiShardStream
from aistore.pytorch.dataset import AISDataset
from aistore.pytorch.multishard_dataset import AISMultiShardStream
from aistore.pytorch.iter_dataset import AISIterDataset


class TestAISDataset(unittest.TestCase):
@@ -15,14 +17,14 @@ def setUp(self) -> None:
]

self.patcher_list_objects_iterator = patch(
"aistore.pytorch.dataset.list_objects_iterator",
"aistore.pytorch.base_dataset.list_objects_iterator",
return_value=iter(self.mock_objects),
)
self.patcher_list_objects = patch(
"aistore.pytorch.dataset.list_objects", return_value=self.mock_objects
"aistore.pytorch.base_dataset.list_objects", return_value=self.mock_objects
)
self.patcher_client = patch(
"aistore.pytorch.dataset.Client", return_value=self.mock_client
"aistore.pytorch.base_dataset.Client", return_value=self.mock_client
)
self.patcher_list_objects_iterator.start()
self.patcher_list_objects.start()
@@ -53,7 +55,7 @@ def test_iter_dataset(self):

def test_multi_shard_stream(self):
self.patcher = unittest.mock.patch(
"aistore.pytorch.dataset.list_shard_objects_iterator"
"aistore.pytorch.multishard_dataset.list_shard_objects_iterator"
)
self.mock_list_shard_objects_iterator = self.patcher.start()

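The patch targets in these tests change because unittest.mock.patch must target the module where a name is looked up at call time — after this refactor that is aistore.pytorch.base_dataset (or aistore.pytorch.multishard_dataset), not aistore.pytorch.dataset. A minimal illustration of the rule:

from unittest.mock import patch

# Correct: patch the name in the module that uses it after the refactor.
with patch("aistore.pytorch.base_dataset.list_objects", return_value=[]):
    ...

# Patching "aistore.pytorch.utils.list_objects" (where it is defined) would not
# affect the reference already imported into base_dataset.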
