Skip to content

Commit

Permalink
added hms dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Artur-Galstyan committed Apr 4, 2024
1 parent 8a47fd6 commit 65ba10b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
34 changes: 31 additions & 3 deletions jaxonloader/datasets/raw.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import pathlib
import pickle
import zipfile
from typing import Any

import pandas as pd
from loguru import logger
from numpy import ndarray as NDArray

from jaxonloader.utils import (
download_and_extract_zip,
download,
jaxonloader_cache,
JAXONLOADER_PATH,
)
Expand All @@ -24,9 +27,34 @@ def get_hms() -> tuple[dict[str, NDArray], dict[str, NDArray], dict[Any, Any]]:
- `train_df: dict[Any, Any]`: A dictionary of the training dataframe.
"""
data_url = "https://omnisium.eu-central-1.linodeobjects.com/hms/hms.zip"
data_url = "https://omnisium.eu-central-1.linodeobjects.com/hms/"
data_path = pathlib.Path(JAXONLOADER_PATH) / "hms"
download_and_extract_zip(data_url, data_path)
expected_files = [
"eegs.pkl",
"spectrograms.pkl",
"train.csv",
]

if not all((data_path / file).exists() for file in expected_files):
parts = [
"part_aa",
"part_ab",
"part_ac",
"part_ad",
"part_ae",
"part_af",
"part_ag",
]
for part in parts:
download(data_url + part, data_path)
os.system(f"cat {data_path}/part_* > {data_path}/hms.zip")
with zipfile.ZipFile(data_path / "hms.zip", "r") as zip_ref:
logger.info(f"Extracting the dataset to {data_path}")
zip_ref.extractall(data_path)
logger.info("Cleaning up the dataset")
os.remove(data_path / "hms.zip")
for part in parts:
os.remove(data_path / part)

with open(data_path / "eegs.pkl", "rb") as f:
eegs = pickle.load(f)
Expand Down
10 changes: 10 additions & 0 deletions jaxonloader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ def show_progress(block_num, block_size, total_size):
pbar = None


def download(url: str, data_path: pathlib.Path) -> None:
    """Download the file at ``url`` into the directory ``data_path``.

    The local file name is the last ``/``-separated segment of ``url``.
    ``data_path`` is created (including parents) when it does not exist yet.

    Args:
        url: Address of the file to fetch.
        data_path: Directory the downloaded file is written to.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test when two callers prepare the same directory.
    os.makedirs(data_path, exist_ok=True)

    file_name = url.split("/")[-1]
    logger.info(f"Downloading from {url}")
    # show_progress is passed as urlretrieve's report hook.
    urllib.request.urlretrieve(url, data_path / file_name, show_progress)

    # Drop a "__MACOSX" directory if one is present in the target directory.
    macosx_dir = data_path / "__MACOSX"
    if os.path.exists(macosx_dir):
        shutil.rmtree(macosx_dir)


def download_and_extract_zip(url: str, data_path: pathlib.Path) -> None:
if os.path.exists(data_path / ".DS_Store"):
os.remove(data_path / ".DS_Store")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "jaxonloader"
version = "0.2.7"
version = "0.2.8"
description = "A dataloader, but for JAX"
readme = "README.md"
requires-python ="~=3.10"
Expand Down

0 comments on commit 65ba10b

Please sign in to comment.