From 65ba10bc9b8dc2cf32af455469af35acf1de9067 Mon Sep 17 00:00:00 2001 From: Artur Galstyan Date: Thu, 4 Apr 2024 17:52:04 +0200 Subject: [PATCH] added hms dataset --- jaxonloader/datasets/raw.py | 34 +++++++++++++++++++++++++++++++--- jaxonloader/utils.py | 10 ++++++++++ pyproject.toml | 2 +- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/jaxonloader/datasets/raw.py b/jaxonloader/datasets/raw.py index c80b20f..091c11f 100644 --- a/jaxonloader/datasets/raw.py +++ b/jaxonloader/datasets/raw.py @@ -1,12 +1,15 @@ +import os import pathlib import pickle +import zipfile from typing import Any import pandas as pd +from loguru import logger from numpy import ndarray as NDArray from jaxonloader.utils import ( - download_and_extract_zip, + download, jaxonloader_cache, JAXONLOADER_PATH, ) @@ -24,9 +27,34 @@ def get_hms() -> tuple[dict[str, NDArray], dict[str, NDArray], dict[Any, Any]]: - `train_df: dict[Any, Any]`: A dictionary of the training dataframe. """ - data_url = "https://omnisium.eu-central-1.linodeobjects.com/hms/hms.zip" + data_url = "https://omnisium.eu-central-1.linodeobjects.com/hms/" data_path = pathlib.Path(JAXONLOADER_PATH) / "hms" - download_and_extract_zip(data_url, data_path) + expected_files = [ + "eegs.pkl", + "spectrograms.pkl", + "train.csv", + ] + + if not all((data_path / file).exists() for file in expected_files): + parts = [ + "part_aa", + "part_ab", + "part_ac", + "part_ad", + "part_ae", + "part_af", + "part_ag", + ] + for part in parts: + download(data_url + part, data_path) + os.system(f"cat {data_path}/part_* > {data_path}/hms.zip") + with zipfile.ZipFile(data_path / "hms.zip", "r") as zip_ref: + logger.info(f"Extracting the dataset to {data_path}") + zip_ref.extractall(data_path) + logger.info("Cleaning up the dataset") + os.remove(data_path / "hms.zip") + for part in parts: + os.remove(data_path / part) with open(data_path / "eegs.pkl", "rb") as f: eegs = pickle.load(f) diff --git a/jaxonloader/utils.py b/jaxonloader/utils.py index 1244279..b885d01 100644 --- a/jaxonloader/utils.py +++ b/jaxonloader/utils.py @@ -70,6 +70,16 @@ def show_progress(block_num, block_size, total_size): pbar = None +def download(url: str, data_path: pathlib.Path) -> None: + if not os.path.exists(data_path): + os.makedirs(data_path) + file_name = url.split("/")[-1] + logger.info(f"Downloading from {url}") + urllib.request.urlretrieve(url, data_path / file_name, show_progress) + if os.path.exists(data_path / "__MACOSX"): + shutil.rmtree(data_path / "__MACOSX") + + def download_and_extract_zip(url: str, data_path: pathlib.Path) -> None: if os.path.exists(data_path / ".DS_Store"): os.remove(data_path / ".DS_Store") diff --git a/pyproject.toml b/pyproject.toml index 174b43b..beedd87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "jaxonloader" -version = "0.2.7" +version = "0.2.8" description = "A dataloader, but for JAX" readme = "README.md" requires-python ="~=3.10"