diff --git a/README.md b/README.md index 3388b13..180d46d 100644 --- a/README.md +++ b/README.md @@ -109,24 +109,91 @@ Example: arcee.stage("preparing") ``` -## Logging datasets -To log a dataset, use the `dataset` method with the following parameters: -- path (str, required): the dataset path. +## Datasets +### Logging +Logging a dataset creates the dataset or, if it already exists but has +changed, a new version of it. +To create a dataset, use the `Dataset` class with the following parameters: + +Dataset parameters: +- key (str, required): the unique dataset key. - name (str, optional): the dataset name. - description (str, optional): the dataset description. - labels (list, optional): the dataset labels. + +Version parameters: +- aliases (list, optional): the list of aliases for this version. +- meta (dict, optional): the dataset version metadata. +- timespan_from (int, optional): the start of the dataset version timespan. +- timespan_to (int, optional): the end of the dataset version timespan. +```sh +dataset = arcee.Dataset(key='YOUR-DATASET-KEY', + name='YOUR-DATASET-NAME', + description="YOUR-DATASET-DESCRIPTION", + ... + ) +dataset.labels = ["YOUR-DATASET-LABEL-1", "YOUR-DATASET-LABEL-2"] +dataset.aliases = ['YOUR-VERSION-ALIAS'] +``` +To log a dataset, use the `log_dataset` method with the following parameters: +- dataset (Dataset, required): the dataset object. +- comment (str, optional): the usage comment. +```sh +arcee.log_dataset(dataset=dataset, comment='LOGGING_COMMENT') +``` + +### Using +To use a dataset, use the `use_dataset` method with the dataset `key:version`. +Parameters: +- dataset (str, required): the dataset identifier in key:version format. +- comment (str, optional): the usage comment. +```sh +dataset = arcee.use_dataset( + dataset='YOUR-DATASET-KEY:YOUR-DATASET-VERSION-OR-ALIAS') +``` + +### Adding files and downloading +You can add files to a dataset, remove files from it, and download it. 
+Supported file paths: +- `file://` - local files. +- `s3://` - Amazon S3 files. + +Adding / removing files: + +local: +```sh +dataset.remove_file(path='file://LOCAL_PATH_TO_FILE_1') +dataset.add_file(path='file://LOCAL_PATH_TO_FILE_2') +arcee.log_dataset(dataset=dataset) +``` +s3: ```sh -arcee.dataset(path="YOUR-DATASET-PATH", - name="YOUR-DATASET-NAME", - description="YOUR-DATASET-DESCRIPTION", - labels=["YOUR-DATASET-LABEL-1", "YOUR-DATASET-LABEL-2"]) +os.environ['AWS_ACCESS_KEY_ID'] = 'AWS_ACCESS_KEY_ID' +os.environ['AWS_SECRET_ACCESS_KEY'] = 'AWS_SECRET_ACCESS_KEY' +dataset.remove_file(path='s3://BUCKET/PATH_1') +dataset.add_file(path='s3://BUCKET/PATH_2') +arcee.log_dataset(dataset=dataset) +``` +Downloading: +Parameters: +- overwrite (bool, optional): overwrite files that already exist locally; +when False, files that already exist are not downloaded again. +```sh +dataset.download(overwrite=True) ``` Example: ```sh -arcee.dataset("https://s3/ml-bucket/datasets/training_dataset.csv", - name="Training dataset", - description="Training dataset (100k rows)", - labels=["training", "100k"]) +# use version V0, V1, etc., or any version alias: my_dataset:latest +dataset = arcee.use_dataset(dataset='my_dataset:V0') +path_map = dataset.download() +for local_path in path_map.values(): + with open(local_path, 'r') as f: + data = f.read() # read downloaded file + +new_dataset = arcee.Dataset('new_dataset') +new_dataset.add_file(path='s3://ml-bucket/datasets/training_dataset.csv') +arcee.log_dataset(dataset=new_dataset) +new_dataset.download() ``` ## Creating models @@ -221,4 +288,4 @@ arcee.finish() To fail a run, use the `error` method. 
"""Example: log a dataset with one local file, download it, then log a
new version of the dataset without the file."""
import os

import kiroframe_arcee as arcee

# Dataset file paths must start with a supported prefix ('file://' or
# 's3://'); a bare filesystem path is rejected with
# TypeError('Unhandled path type'). Build a file:// path for this script.
file_path = 'file://' + os.path.abspath(__file__)

with arcee.init("test", "simple"):
    dataset = arcee.Dataset(key='test_dataset', description='test dataset')

    # adding file
    dataset.add_file(path=file_path)

    # log new dataset_version (with file)
    arcee.log_dataset(dataset)
    print('Actual dataset version: ', dataset._version)

    # downloading
    dataset.download()

    # remove file from dataset
    dataset.remove_file(path=file_path)

    # log new dataset_version (without files)
    arcee.log_dataset(dataset)
    print('Actual dataset version: ', dataset._version)
def log_dataset(dataset: Dataset, comment: str = None):
    """
    Log dataset: register the dataset for the current run and store the
    version number returned by the server on the dataset object.
    Args:
        dataset: the dataset object; nothing is sent when it is falsy
        comment: the usage comment
    """
    arcee = Arcee()
    if not dataset:
        return
    # Block until the dataset's background tasks have finished, so the
    # request body is built from complete file information.
    dataset.wait_ready()
    request = arcee.sender.register_dataset(
        arcee.token, arcee.run, arcee.name, arcee.task_key,
        body=dataset.__dict__, comment=comment
    )
    registered = asyncio.run(request)
    dataset._version = registered["version"]["version"]


def use_dataset(dataset: str, comment: str = None) -> Dataset:
    """
    Use dataset
    Args:
        dataset: the dataset identifier in key:version format
        comment: the usage comment
    Returns: Dataset built from the server response
    """
    arcee = Arcee()
    request = arcee.sender.use_dataset(
        arcee.token, arcee.run, dataset, comment=comment)
    response = asyncio.run(request)
    return Dataset.from_response(response)
LOCAL_PREFIX = 'file://'
S3_PREFIX = 's3://'
BASE_PATH = 'kiroframe/datasets/%s/'


class DatasetThread(threading.Thread):
    """Thread that records an exception raised by its target.

    The exception is stored in ``self.exception`` instead of propagating,
    so the code that joins the thread can re-raise it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exception = None

    def run(self):
        try:
            if self._target:
                self._target(*self._args, **self._kwargs)
        except Exception as e:
            self.exception = e


class Dataset(object):
    """Dataset description plus the files of one dataset version.

    Public attributes are sent to the server when the dataset is logged
    (see the ``__dict__`` property). Private attributes:
    ``_tasks`` - background threads started by add_file / download,
    ``_files`` - mapping path -> {'path', 'size', 'digest'} (the value is
    None while the file info is still being collected),
    ``_version`` - version number, None until the dataset is logged or
    after a file has been added.
    """
    __slots__ = ('key', 'name', 'description', 'labels', 'meta',
                 'timespan_from', 'timespan_to', 'aliases',
                 '_tasks', '_files', '_version')

    def __init__(self, key: str, name: str = None, description: str = None,
                 labels: List[str] = None, meta: Dict = None,
                 timespan_from: int = None, timespan_to: int = None,
                 aliases: List[str] = None):
        self._tasks: List = []
        self._files: Dict = {}
        self._version: int = None
        self.key: str = key
        self.name: str = name
        self.description: str = description
        self.labels: List[str] = labels or []
        self.meta: Dict = meta or {}
        self.timespan_from: int = timespan_from
        self.timespan_to: int = timespan_to
        self.aliases: List[str] = aliases or []

    @classmethod
    def from_response(cls, response):
        """Build a Dataset from a server response dict.

        The response holds dataset fields at the top level and version
        fields (including 'version' and 'files') under the 'version' key.
        The passed dict is not modified.
        """
        # Work on copies: the caller's dict must stay untouched.
        data = dict(response)
        version = dict(data.pop('version', None) or {})
        files = version.pop('files', None) or []
        data.update(version)
        # Only public slots are constructor arguments.
        init_fields = [k for k in cls.__slots__ if not k.startswith('_')]
        obj = cls(**{k: data[k] for k in init_fields if k in data})
        # A response without version info yields an unlogged dataset
        # instead of a KeyError.
        obj._version = version.get('version')
        obj._files = {
            f['path']: {
                'path': f['path'],
                'size': f['size'],
                'digest': f['digest']
            } for f in files
        }
        return obj

    @property
    def __dict__(self):
        """Request body: the set public fields plus the list of files."""
        res = dict()
        for k in self.__slots__:
            if k.startswith('_'):
                continue
            value = getattr(self, k)
            if value is None:
                continue
            # Empty strings / lists / dicts count as "not set", but a
            # numeric 0 (e.g. timespan_from=0) is a real value.
            if not value and not isinstance(value, (int, float)):
                continue
            res[k] = value
        res['files'] = list(self._files.values())
        return res

    @staticmethod
    def _strip_prefix(path, prefix):
        """Return path without its leading prefix.

        str.strip() must not be used for this: it removes a *set of
        characters* from both ends ('file://file' would become '').
        """
        return path[len(prefix):]

    def _get_provider(self, path):
        """Return (provider module, provider-specific path) for path.

        Raises:
            TypeError: the path has no supported prefix.
        """
        if path.startswith(LOCAL_PREFIX):
            return local_file, self._strip_prefix(path, LOCAL_PREFIX)
        if path.startswith(S3_PREFIX):
            return amazon, path
        raise TypeError('Unhandled path type')

    def _add_file(self, path):
        """Collect digest and size of path; runs in a DatasetThread."""
        try:
            provider, local_path = self._get_provider(path)
            digest, size = asyncio.run(provider.get_file_info(local_path))
        except Exception:
            # Drop the pending placeholder so a failed add does not leave
            # a None entry in the file list.
            if self._files.get(path) is None:
                self._files.pop(path, None)
            raise
        self._files[path] = {
            'path': path,
            'size': size,
            'digest': digest
        }

    def add_file(self, path):
        """Add a file to the dataset; its info is collected in background.

        Marks the dataset as not logged (``_version`` is reset to None).
        A path that is already part of a not yet logged dataset is skipped.
        """
        if path in self._files and self._version is None:
            return
        self._version = None
        self._files[path] = None  # placeholder until the thread finishes
        thr = DatasetThread(target=self._add_file, args=(path, ))
        thr.start()
        self._tasks.append(thr)

    def remove_file(self, path):
        """Remove a file from the dataset; unknown paths are ignored."""
        if path in self._files:
            del self._files[path]

    def wait_ready(self):
        """Wait for all background tasks and re-raise the first failure.

        The task list is drained, so finished threads do not pile up and a
        failure is reported once instead of on every later call.
        """
        tasks, self._tasks = self._tasks, []
        first_error = None
        for task in tasks:
            task.join()
            if task.exception is not None and first_error is None:
                first_error = task.exception
        if first_error is not None:
            raise first_error

    def download(self, overwrite=True) -> dict:
        """Download all dataset files below BASE_PATH % '<key>:V<version>'.

        Args:
            overwrite: when False, files that already exist at the
                destination are not downloaded again.
        Returns:
            dict mapping dataset file path -> local download path.
        Raises:
            TypeError: the dataset has not been logged yet.
        """
        if self._version is None:
            raise TypeError('Dataset is not logged')
        name = f'{self.key}:V{self._version}'
        destination = BASE_PATH % name
        download_map = dict()
        print('Downloading %s' % name)
        for path, file in self._files.items():
            # NOTE(review): only the last path component is used, so two
            # files with the same base name share one download path.
            file_name = path.split('/')[-1]
            download_path = destination + file_name
            download_map[path] = download_path
            if not overwrite and os.path.isfile(download_path):
                continue
            thr = DatasetThread(
                target=self._download, args=(file, destination, file_name))
            thr.start()
            self._tasks.append(thr)
        self.wait_ready()
        print('Download completed: %s' % name)
        return download_map

    def _download(self, file, destination, file_name):
        """Download one file via its provider; runs in a DatasetThread."""
        digest = file['digest']
        path = file['path']
        provider, local_path = self._get_provider(path)
        asyncio.run(
            provider.download(
                local_path, digest, destination, file_name
            )
        )
_KB: int = 1_024
_CHUNKSIZE: int = 128 * _KB


async def get_file_info(path):
    """Return (md5 hex digest, size in bytes) of the local file at path."""
    digest = await _get_md5(path)
    size = await _get_size(path)
    return digest, size


async def _get_md5(path):
    """Return the md5 hex digest of the file at path.

    The file is read in _CHUNKSIZE pieces with the builtin ``open``. The
    previous fallback called ``read()`` on an async file handle without
    awaiting it and fed the resulting coroutine object to the hash, which
    raises TypeError. A plain chunked read also handles empty files
    without special-casing.
    """
    if sys.version_info >= (3, 9):
        # md5 is used as a content fingerprint only, not for security
        md_5_hash = hashlib.md5(usedforsecurity=False)
    else:
        md_5_hash = hashlib.md5()
    with open(path, "rb") as f:
        chunk = f.read(_CHUNKSIZE)
        while chunk:
            md_5_hash.update(chunk)
            chunk = f.read(_CHUNKSIZE)
    return md_5_hash.hexdigest()


async def _get_size(path):
    """Return the size of the file at path in bytes."""
    return os.stat(path).st_size


async def download(path, digest, dest_path, file_name):
    """Copy the local file at path to dest_path/file_name.

    Args:
        path: source file path.
        digest: md5 hex digest the source file is expected to have.
        dest_path: destination directory, created when missing.
        file_name: name of the copy inside dest_path.
    Raises:
        ValueError: the source file does not exist, or its digest differs
            from the expected one.
    """
    if not os.path.exists(path):
        raise ValueError('Failed to find file path %s' % path)
    md5 = await _get_md5(path)
    if md5 != digest:
        raise ValueError(
            'Cannot download dataset file %s. Source file has been changed' %
            path)
    os.makedirs(dest_path, exist_ok=True)
    # os.path.join gives the same result as plain concatenation when
    # dest_path ends with a separator, and also works when it does not.
    shutil.copy(path, os.path.join(dest_path, file_name))
headers = {"x-api-key": token, "Content-Type": "application/json"} + return await self.send_post_request( + uri, headers, {"dataset": dataset, "comment": comment}) @check_shutdown_flag_set async def add_hyperparams(self, run_id, token, hyperparams): diff --git a/setup.cfg b/setup.cfg index cde40c0..53f3ba3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ # setup.cfg [metadata] name = kiroframe_arcee -version = 0.1.49 +version = 0.1.50 author = Hystax description = ML profiling tool for Kiroframe long_description = file: README.md @@ -15,8 +15,9 @@ keywords = arcee, ml, kiroframe, finops, mlops python_requires = >=3.8, <4 install_requires = aiohttp==3.10.11 - aiofiles==22.1.0 - psutil==5.9.2 + aiofiles==23.2.1 + psutil==7.0.0 + aioboto3==14.1.0 packages = find: package_dir = =.