Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 79 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,24 +109,91 @@ Example:
arcee.stage("preparing")
```

## Logging datasets
To log a dataset, use the `dataset` method with the following parameters:
- path (str, required): the dataset path.
## Datasets
### Logging
Logging a dataset creates the dataset, or creates a new version of it if the
dataset already exists and has been changed.
To create a dataset, use the `Dataset` class with the following parameters:

Dataset parameters:
- key (str, required): the unique dataset key.
- name (str, optional): the dataset name.
- description (str, optional): the dataset description.
- labels (list, optional): the dataset labels.

Version parameters:
- aliases (list, optional): the list of aliases for this version.
- meta (dict, optional): the dataset version meta.
- timespan_from (int, optional): the dataset version timespan from.
- timespan_to (int, optional): the dataset version timespan to.
```sh
dataset = arcee.Dataset(key='YOUR-DATASET-KEY',
name='YOUR-DATASET-NAME',
description="YOUR-DATASET-DESCRIPTION",
...
)
dataset.labels = ["YOUR-DATASET-LABEL-1", "YOUR-DATASET-LABEL-2"]
dataset.aliases = ['YOUR-VERSION-ALIAS']
```
To log a dataset, use the `log_dataset` method with the following parameters:
- dataset (Dataset, required): the dataset object.
- comment (str, optional): the usage comment.
```sh
arcee.log_dataset(dataset=dataset, comment='LOGGING_COMMENT')
```

### Using
To use a dataset, use the `use_dataset` method with dataset `key:version`.
Parameters:
- dataset (str, required): the dataset identifier in key:version format.
- comment (str, optional): the usage comment.
```sh
dataset = arcee.use_dataset(
dataset='YOUR-DATASET-KEY:YOUR-DATASET-VERSION-OR-ALIAS')
```

### Adding files and downloading
You can add files to or remove files from a dataset, and download it as well.
Supported file paths:
- `file://` - local files.
- `s3://` - Amazon S3 files.

adding / removing files

local:
```sh
dataset.remove_file(path='file://LOCAL_PATH_TO_FILE_1')
dataset.add_file(path='file://LOCAL_PATH_TO_FILE_2')
arcee.log_dataset(dataset=dataset)
```
s3:
```sh
arcee.dataset(path="YOUR-DATASET-PATH",
name="YOUR-DATASET-NAME",
description="YOUR-DATASET-DESCRIPTION",
labels=["YOUR-DATASET-LABEL-1", "YOUR-DATASET-LABEL-2"])
os.environ['AWS_ACCESS_KEY_ID'] = 'AWS_ACCESS_KEY_ID'
os.environ['AWS_SECRET_ACCESS_KEY'] = 'AWS_SECRET_ACCESS_KEY'
dataset.remove_file(path='s3://BUCKET/PATH_1')
dataset.add_file(path='s3://BUCKET/PATH_2')
arcee.log_dataset(dataset=dataset)
```
downloading:
Parameters:
- overwrite (bool, optional): overwrite an existing dataset or skip
downloading if it already exists.
```sh
dataset.download(overwrite=True)
```
Example:
```sh
arcee.dataset("https://s3/ml-bucket/datasets/training_dataset.csv",
name="Training dataset",
description="Training dataset (100k rows)",
labels=["training", "100k"])
# use version v0, v1 etc, or any version alias: my_dataset:latest
dataset = arcee.use_dataset(dataset='my_dataset:V0')
path_map = dataset.download()
for local_path in path_map.values():
    with open(local_path, 'r') as f:
        data = f.read()  # read downloaded file

new_dataset = arcee.Dataset('new_dataset')
new_dataset.add_file(path='s3://ml-bucket/datasets/training_dataset.csv')
arcee.log_dataset(dataset=new_dataset)
new_dataset.download()
```

## Creating models
Expand Down Expand Up @@ -221,4 +288,4 @@ arcee.finish()
To fail a run, use the `error` method.
```sh
arcee.error()
```
```
21 changes: 21 additions & 0 deletions examples/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os

import kiroframe_arcee as arcee

# Dataset file paths must carry a scheme prefix ("file://" for local files,
# "s3://" for S3 objects); a bare filesystem path is rejected by the dataset
# module with "Unhandled path type". Build a prefixed path for this script.
EXAMPLE_FILE = 'file://' + os.path.abspath(__file__)

with arcee.init("test", "simple"):
    dataset = arcee.Dataset(key='test_dataset', description='test dataset')

    # adding file
    dataset.add_file(path=EXAMPLE_FILE)

    # log new dataset_version (with file)
    arcee.log_dataset(dataset)
    print('Actual dataset version: ', dataset._version)

    # downloading
    dataset.download()

    # remove file from dataset
    dataset.remove_file(path=EXAMPLE_FILE)

    # log new dataset_version (without files)
    arcee.log_dataset(dataset)
    print('Actual dataset version: ', dataset._version)
6 changes: 3 additions & 3 deletions kiroframe_arcee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# flake8: noqa: F401
from .arcee import (init, send, tag, milestone, info, finish, error, stage,
dataset, hyperparam, model, model_version,
model_version_alias, model_version_tag, artifact,
artifact_tag)
hyperparam, model, model_version, model_version_alias,
model_version_tag, artifact, artifact_tag, Dataset,
log_dataset, use_dataset)
37 changes: 22 additions & 15 deletions kiroframe_arcee/arcee.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
acquire_console, release_console)
from kiroframe_arcee.name_generator import NameGenerator
from kiroframe_arcee.utils import single
from kiroframe_arcee.modules.dataset import Dataset


class Job(threading.Thread):
Expand Down Expand Up @@ -49,7 +50,6 @@ def __init__(
self._tags = dict()
self._name = None
self._hyperparams = dict()
self._dataset = None
self._model = None
self._model_version = None
self._model_version_tags = dict()
Expand Down Expand Up @@ -90,14 +90,6 @@ def hyperparams(self, value):
k, v = value
self._hyperparams.update({k: v})

@property
def dataset(self):
return self._dataset

@dataset.setter
def dataset(self, value):
self._dataset = value

def __enter__(self):
return self

Expand Down Expand Up @@ -225,14 +217,29 @@ def stage(name):
asyncio.run(arcee.sender.create_stage(arcee.run, arcee.token, name))


def dataset(path, name=None, description=None, labels=None):
def log_dataset(dataset: Dataset, comment: str = None):
arcee = Arcee()
if arcee.dataset is None:
arcee.dataset = path
asyncio.run(arcee.sender.register_dataset(
arcee.token, arcee.run, arcee.name, arcee.task_key, path, name,
description, labels
if dataset:
dataset.wait_ready()
dataset_dict = asyncio.run(arcee.sender.register_dataset(
arcee.token, arcee.run, arcee.name, arcee.task_key,
body=dataset.__dict__, comment=comment
))
dataset._version = dataset_dict["version"]["version"]


def use_dataset(dataset: str, comment: str = None) -> Dataset:
    """
    Register usage of an existing dataset version and return it.

    Args:
        dataset: the dataset identifier in key:version format
        comment: the usage comment
    Returns: Dataset built from the service response
    """
    client = Arcee()
    request = client.sender.use_dataset(
        client.token, client.run, dataset, comment=comment)
    response = asyncio.run(request)
    return Dataset.from_response(response)


def _send_console():
Expand Down
Empty file.
140 changes: 140 additions & 0 deletions kiroframe_arcee/modules/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import os
import asyncio
import threading
from typing import List, Dict
from kiroframe_arcee.modules.providers import local_file, amazon

LOCAL_PREFIX = 'file://'
S3_PREFIX = 's3://'
BASE_PATH = 'kiroframe/datasets/%s/'


class DatasetThread(threading.Thread):
    """Thread that records an exception raised by its target.

    Instead of letting the exception escape the worker, it is stored on
    ``exception`` so the code that joins the thread can inspect and
    re-raise it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Stays None unless the target raises.
        self.exception = None

    def run(self):
        """Call the target, if any, capturing any ``Exception`` it raises."""
        target = self._target
        if not target:
            return
        try:
            target(*self._args, **self._kwargs)
        except Exception as exc:
            self.exception = exc


class Dataset(object):
    """A dataset description plus the files that belong to one version of it.

    Public attributes (``key``, ``name``, ``description``, ``labels``,
    ``meta``, ``timespan_from``, ``timespan_to``, ``aliases``) describe the
    dataset and its version. ``_files`` maps each file path to its
    ``path``/``size``/``digest`` record, ``_tasks`` holds the background
    threads started by ``add_file``/``download``, and ``_version`` is the
    logged version number (``None`` while the dataset has unlogged changes).
    """

    __slots__ = ('key', 'name', 'description', 'labels', 'meta',
                 'timespan_from', 'timespan_to', 'aliases',
                 '_tasks', '_files', '_version')

    def __init__(self, key: str, name: str = None, description: str = None,
                 labels: List[str] = None, meta: Dict = None,
                 timespan_from: int = None, timespan_to: int = None,
                 aliases: List[str] = None):
        self._tasks: List = []
        self._files: Dict = {}
        self._version: int = None
        self.key: str = key
        self.name: str = name
        self.description: str = description
        self.labels: List[str] = labels or []
        self.meta: Dict = meta or {}
        self.timespan_from: int = timespan_from
        self.timespan_to: int = timespan_to
        self.aliases: List[str] = aliases or []

    @classmethod
    def from_response(cls, response):
        """Build a Dataset from a response dict.

        The response is expected to hold the dataset fields at the top level
        and a nested ``version`` dict carrying the version fields, the
        ``version`` number and a ``files`` list. The caller's dicts are not
        modified.

        Raises:
            KeyError: if the nested version dict has no ``version`` entry.
        """
        # Work on shallow copies: the original implementation popped keys
        # straight out of the caller's response.
        payload = dict(response)
        version = dict(payload.pop('version', None) or {})
        files = version.pop('files', None) or []
        # Flatten version-level fields (aliases, meta, ...) next to the
        # dataset-level ones so they can be passed to the constructor.
        payload.update(version)
        obj = cls(**{
            k: payload[k] for k in cls.__slots__
            if not k.startswith('_') and k in payload
        })
        obj._version = version['version']
        obj._files = {
            f['path']: {
                'path': f['path'],
                'size': f['size'],
                'digest': f['digest'],
            } for f in files
        }
        return obj

    @property
    def __dict__(self):
        """Return the public, non-empty fields plus a ``files`` list.

        NOTE(review): falsy values (``0``, ``''``, empty containers) are
        omitted, so e.g. ``timespan_from=0`` is not included -- confirm this
        is intended.
        """
        res = {}
        for k in self.__slots__:
            if k.startswith('_'):
                continue
            value = getattr(self, k)
            if value:
                res[k] = value
        res['files'] = list(self._files.values())
        return res

    @staticmethod
    def _strip_prefix(path, prefix):
        """Return ``path`` without a leading ``prefix`` (if present).

        ``str.strip(prefix)`` must not be used for this: it removes any of
        the prefix's *characters* from both ends of the string.
        """
        if path.startswith(prefix):
            return path[len(prefix):]
        return path

    def _get_provider(self, path):
        """Return ``(provider_module, provider_path)`` for a dataset path.

        Local paths are handed to the provider without the ``file://``
        prefix; S3 paths are passed through unchanged.

        Raises:
            TypeError: if the path has neither supported prefix.
        """
        if path.startswith(LOCAL_PREFIX):
            return local_file, self._strip_prefix(path, LOCAL_PREFIX)
        if path.startswith(S3_PREFIX):
            return amazon, path
        raise TypeError('Unhandled path type')

    def _add_file(self, path):
        """Thread target: fetch digest/size for ``path`` and record them."""
        try:
            provider, local_path = self._get_provider(path)
            digest, size = asyncio.run(provider.get_file_info(local_path))
        except Exception:
            # Drop the placeholder left by add_file so a failed file is not
            # reported as a ``None`` entry and can be added again later.
            if self._files.get(path) is None:
                self._files.pop(path, None)
            raise
        self._files[path] = {
            'path': path,
            'size': size,
            'digest': digest
        }

    def add_file(self, path):
        """Schedule ``path`` to be added to the dataset in the background.

        The dataset is marked as not logged (``_version`` reset to ``None``).
        Nothing happens if the path is already part of a not-yet-logged
        dataset. Failures surface from ``wait_ready``.
        """
        if path in self._files and self._version is None:
            return
        self._version = None
        # Placeholder until the background task fills in size/digest.
        self._files[path] = None
        thr = DatasetThread(target=self._add_file, args=(path, ))
        thr.start()
        self._tasks.append(thr)

    def remove_file(self, path):
        """Remove ``path`` from the dataset; unknown paths are ignored."""
        self._files.pop(path, None)

    def wait_ready(self):
        """Wait for all pending background tasks.

        Every task is joined and the task list is emptied, so a failure is
        reported once instead of on every later call.

        Raises:
            Exception: the first exception captured by a pending task.
        """
        tasks, self._tasks = self._tasks, []
        first_error = None
        for task in tasks:
            task.join()
            if first_error is None and task.exception:
                first_error = task.exception
        if first_error is not None:
            raise first_error

    def download(self, overwrite=True) -> dict:
        """Download the dataset files and return ``{path: local_path}``.

        Args:
            overwrite: when False, files that already exist locally are not
                downloaded again.

        Raises:
            TypeError: if the dataset has no logged version.
        """
        download_map = dict()
        if self._version is None:
            raise TypeError('Dataset is not logged')
        name = f'{self.key}:V{self._version}'
        destination = BASE_PATH % name
        print('Downloading %s' % name)
        for path, file in self._files.items():
            # Only the last path segment is kept as the local file name.
            file_name = path.split('/')[-1]
            download_path = destination + file_name
            download_map[path] = download_path
            if not overwrite and os.path.isfile(download_path):
                continue
            thr = DatasetThread(
                target=self._download, args=(file, destination, file_name))
            thr.start()
            self._tasks.append(thr)
        self.wait_ready()
        print('Download completed: %s' % name)
        return download_map

    def _download(self, file, destination, file_name):
        """Thread target: download one file record via its provider."""
        digest = file['digest']
        path = file['path']
        provider, local_path = self._get_provider(path)
        asyncio.run(
            provider.download(
                local_path, digest, destination, file_name
            )
        )
4 changes: 4 additions & 0 deletions kiroframe_arcee/modules/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from kiroframe_arcee.modules.providers import local_file
from kiroframe_arcee.modules.providers import amazon

__all__ = ['local_file', 'amazon']
44 changes: 44 additions & 0 deletions kiroframe_arcee/modules/providers/amazon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
import aioboto3
import aiofiles
from urllib.parse import urlparse


async def get_file_info(path):
    """Return ``(etag, size)`` for the S3 object addressed by ``path``.

    The ETag is returned without its surrounding double quotes; the size is
    the object's ``content_length``.
    """
    bucket, key = await _parse_uri(path)
    async with aioboto3.Session().resource("s3") as s3:
        s3_object = await s3.Object(bucket, key)
        raw_etag = await s3_object.e_tag
        length = await s3_object.content_length
        return raw_etag.strip('"'), length


async def _parse_uri(path):
url = urlparse(path)
bucket = url.netloc
key = url.path[1:]
return bucket, key


async def download(path, digest, dest_path, file_name):
    """Download the S3 object at ``path`` into ``dest_path``/``file_name``.

    Args:
        path: the object URI; bucket and key are extracted from it.
        digest: the expected ETag (without quotes) of the object.
        dest_path: the destination directory; created if missing.
        file_name: the name of the file to write inside ``dest_path``.

    Raises:
        ValueError: if the object's current ETag differs from ``digest``,
            i.e. the source changed since the digest was recorded.
    """
    bucket, key = await _parse_uri(path)
    session = aioboto3.Session()
    async with session.resource("s3") as s3:
        s3_object = await s3.Object(bucket, key)
        etag = await s3_object.e_tag
        if etag.strip('"') != digest:
            raise ValueError(
                'Cannot download dataset file %s. Source file has been '
                'changed' % path)
        os.makedirs(dest_path, exist_ok=True)
        # os.path.join yields the same path as plain concatenation when
        # dest_path ends with a separator, and a correct one when it does
        # not (concatenation would glue the file name onto the directory).
        target = os.path.join(dest_path, file_name)
        async with aiofiles.open(target, 'wb') as fp:
            await s3_object.download_fileobj(fp)


async def main(bucket, key):
    """Download the S3 object ``bucket``/``key`` into a local file ``TEST``."""
    async with aioboto3.Session().resource("s3") as s3:
        s3_object = await s3.Object(bucket, key)
        async with aiofiles.open('TEST', 'wb') as out:
            await s3_object.download_fileobj(out)
Loading