Merge pull request #3 from Artur-Galstyan/superfast

Superfast

Artur-Galstyan authored Mar 2, 2024
2 parents e12e900 + 082525c commit 8b70097
Showing 15 changed files with 238 additions and 284 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/nox.yaml
@@ -0,0 +1,21 @@
name: nox

on: [push, pull_request]

jobs:
  nox:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Set up Kaggle API credentials
        run: |
          mkdir -p $HOME/.kaggle
          echo '${{ secrets.KAGGLE_TOKEN }}' > $HOME/.kaggle/kaggle.json
          chmod 600 $HOME/.kaggle/kaggle.json
      - name: Install nox
        run: pip install nox
      - name: Run nox
        run: nox
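
The last step just invokes `nox`, which discovers sessions from a `noxfile.py` at the repository root. That file isn't shown in this diff; a minimal hypothetical sketch of what it could look like:

```python
# Hypothetical noxfile.py sketch - not part of this diff. The actual
# sessions in the repository may differ.
import nox


@nox.session(python="3.10")
def tests(session: nox.Session) -> None:
    session.install("-e", ".")
    session.run("pytest")
```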
17 changes: 17 additions & 0 deletions .github/workflows/pre_commit.yaml
@@ -0,0 +1,17 @@
name: Pre-commit

on: [push, pull_request]

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Install everything
        run: pip install -e .
      - name: Run pre-commit
        run: python3 -m pre_commit run --all-files

2 changes: 1 addition & 1 deletion .gitignore
@@ -159,4 +159,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

testing_ground.py
testing_ground/
93 changes: 61 additions & 32 deletions README.md
@@ -1,12 +1,6 @@
# jaxonloader
# Jaxonloader

A dataloader, but for JAX.

The idea of this package is to have a DataLoader similar to the PyTorch one. To ensure that you don't have to learn anything new to use this package, the same API is chosen here (PyTorch's API is actually very good).

Unfortunately, this also means that this package does _not_ follow the functional programming paradigm, because neither does the PyTorch DataLoader API. While in that regard this DataLoader is not _functional_ per se, it still allows for reproducibility, since you provide a random key to shuffle the data (if you want to).

At the moment, this package is not yet a 1:1 mapping from PyTorch's DataLoader, but one day, it will be! \**holding up arm and clenching fist\**
A blazingly fast dataloader for JAX that no one asked for, but here it is anyway.

## Installation

@@ -16,64 +10,99 @@ Install this package using pip like so:
pip install jaxonloader
```

## Usage

Pretty much exactly as you would use PyTorch's DataLoader. Create a dataset class by inheriting from the `jaxonloader` `Dataset` and implementing the `__len__` and `__getitem__` functions. Then simply pass that to the DataLoader class as an argument.
## Quickstart

On the other hand, you can also use some of the provided datasets, such as the MNIST dataset.
This package differs significantly from the PyTorch `DataLoader` class! In JAX,
there is no internal state, which means we have to keep track of it ourselves. Here's
a minimal example to set up MNIST:

```python

import jax

from jaxonloader import get_mnist
from jaxonloader.dataloader import DataLoader
from jaxonloader import make

key = jax.random.PRNGKey(0)

train, test = get_mnist()
# these are JaxonDatasets

train_loader = DataLoader(
train_loader, index = make(
    train,
    batch_size=4,
    shuffle=False,
    drop_last=True,
    key=key,
    jit=True,
)
x = next(iter(train_loader))
print(x[0].shape)  # (4, 784)
print(x[1].shape)  # (4,)
train_loader = jax.jit(train_loader)
while x := train_loader(index):
    data, index, done = x
    processed_data = process_data(data)
    if done:
        break

```

## Philosophy

The `jaxonloader` package is designed to be as lightweight as possible. In fact, it's
only a very thin wrapper around JAX arrays! Under the hood, it uses
the [Equinox library](https://github.com/Patrick-Kidger/equinox) to handle the
stateful nature of the dataloader. Since the dataloader object is just an `eqx.Module`, it
can be JITted and used in other JAX transformations as well (although I haven't tested this).
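
Here's a minimal sketch of that pattern (illustrative only - the fields and logic of the real `JaxonDataLoader` may differ). The point is that the module holds only the data, and the iteration state (the index) is passed in and out explicitly:

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class ToyLoader(eqx.Module):
    # Illustrative stand-in for JaxonDataLoader, not the real implementation.
    data: jax.Array
    batch_size: int = eqx.field(static=True)

    def __call__(self, index):
        start = index * self.batch_size
        batch = jax.lax.dynamic_slice_in_dim(self.data, start, self.batch_size)
        done = (index + 1) * self.batch_size >= self.data.shape[0]
        return batch, index + 1, done


loader = ToyLoader(jnp.arange(12.0).reshape(6, 2), batch_size=2)
step = jax.jit(loader)  # an eqx.Module is a pytree, so it can be jitted directly
batch, index, done = step(jnp.array(0))
```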

## Label & Target Handling

Due to its lightweight nature, this package - as of now - doesn't perform any transformations. This means that you will have to transform your data first and then pass it to the dataloader.
The same goes for post-processing the data.

While in PyTorch, you would do something like this:

```python

for x, y in train_dataloader:
    ...  # do something with x and y

```

## Performing Transformations
In Jaxonloader, we don't split the row of the dataset into `x` and `y` and instead
simply return the whole row. This means that you will have to do the splitting (i.e. data post-processing) yourself.

As of now, transformations are not supported :(
```python
# MNIST example
while x := train_loader(index):
    data, index, done = x
    print(data.shape)  # (4, 785)
    x, y = data[:, :-1], data[:, -1]  # split the data into x and y

    # do something with x and y
```

But - since you can get a dataset from a `DataFrame` - you can first
transform your data and then pass it to the `from_dataframe` function.
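
For example, a quick sketch (hypothetical column names; this assumes the `from_dataframes` function defined in this commit is re-exported at the package top level):

```python
import polars as pl

from jaxonloader import from_dataframes  # assumed top-level re-export

# Transform first (here: scale a made-up "feature" column), then load.
df = pl.DataFrame({"feature": [1.0, 2.0, 3.0], "label": [0, 1, 0]})
df = df.with_columns((pl.col("feature") / pl.col("feature").max()).alias("feature"))
(dataset,) = from_dataframes(df)
```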
## Roadmap

It's not ideal, but it works for now.
The goal is to keep this package as lightweight as possible while also providing as
many datasets as possible. The next steps are to gather more datasets
and to provide a simple API to load them.

---

## Other backends

Other backends are not supported and are not planned to be supported. There is already
a very good dataloader for PyTorch, and with all the support PyTorch has, there's no
need to litter the world with yet another PyTorch dataloader. The same goes for TensorFlow.

If you really need a dataloader that supports all backends, check out

[jax-dataloader](https://github.com/BirkhoffG/jax-dataloader)

which does pretty much the same thing as this package, but for all backends.

## Then why does this package exist?

For one, I just like building things and don't really care whether they're needed. Secondly,
I don't care about other backends (as they are already very well supported); I only want to
focus on JAX, and I needed a lightweight, easy-to-handle package that loads data in JAX.

So if you're like me and just need a simple dataloader for JAX, this package is for you.
If you need a dataloader for all backends, check out the other package linked above.
Also, the PyTorch dataloader is slow! Iterating over the MNIST training set takes
around 2.83 seconds on a MacBook M1 Pro. Unjitted, the JAX dataloader takes 1.5 seconds,
and when jitted, around 0.09 seconds. That makes it roughly 31 times faster than the PyTorch dataloader.
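
If you want to reproduce such a comparison, here is a rough sketch (assuming the quickstart API above with default shuffling options; absolute numbers vary by machine, and the first jitted call additionally pays a one-off compilation cost):

```python
import time

import jax

from jaxonloader import get_mnist, make

key = jax.random.PRNGKey(0)
train, _ = get_mnist()
train_loader, index = make(train, batch_size=4, key=key, jit=True)

start = time.perf_counter()
while x := train_loader(index):
    data, index, done = x
    if done:
        break
print(f"one epoch: {time.perf_counter() - start:.2f}s")
```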
3 changes: 1 addition & 2 deletions jaxonloader/__init__.py
@@ -1,3 +1,2 @@
from jaxonloader._datasets import * # noqa
from jaxonloader.dataset import Dataset # noqa
from jaxonloader.dataloader import DataLoader # noqa
from jaxonloader.dataloader import JaxonDataLoader, make # noqa
85 changes: 29 additions & 56 deletions jaxonloader/_datasets.py
@@ -10,23 +10,18 @@
from jaxtyping import Array
from loguru import logger

from jaxonloader.dataset import Dataset, StandardDataset
from jaxonloader.dataset import JaxonDataset
from jaxonloader.utils import jaxonloader_cache, JAXONLOADER_PATH


@jaxonloader_cache(dataset_name="huggingface")
def get_huggingface_dataset():
    raise NotImplementedError("get_huggingface_dataset is not implemented yet.")


@jaxonloader_cache(dataset_name="kaggle")
def get_kaggle_dataset(
    dataset_name: str,
    force_redownload: bool = False,
    *,
    kaggle_json_path: str | None = None,
    combine_columns_to_row: bool = False,
) -> list[Dataset]:
) -> list[JaxonDataset]:
"""
Get a dataset from Kaggle. You need to have the Kaggle
API token in your home directory. Furthermore,
@@ -44,7 +39,7 @@ def get_kaggle_dataset(
        returned as a tuple.

    Returns:
        A list of datasets. Each dataset is of class StandardDataset.
        A list of datasets.

    Raises:
        FileNotFoundError: If the dataset is not found in Kaggle.
@@ -156,7 +151,7 @@ def get_kaggle_dataset_dataframes(


@jaxonloader_cache(dataset_name="mnist")
def get_mnist() -> tuple[Dataset, Dataset]:
def get_mnist() -> tuple[JaxonDataset, JaxonDataset]:
    MNIST_TRAIN_URL = (
        "https://omnisium.eu-central-1.linodeobjects.com/mnist/mnist_train.csv.zip"
    )
@@ -189,14 +184,11 @@ def get_mnist() -> tuple[Dataset, Dataset]:
    train_df = pl.read_csv(data_path / "mnist_train.csv")
    test_df = pl.read_csv(data_path / "mnist_test.csv")

    x_train = jnp.array(train_df.drop("label").to_numpy())
    y_train = jnp.array(train_df["label"].to_numpy())
    x_train = jnp.array(train_df.to_numpy())
    x_test = jnp.array(test_df.to_numpy())

    x_test = jnp.array(test_df.drop("label").to_numpy())
    y_test = jnp.array(test_df["label"].to_numpy())

    train_dataset = StandardDataset(x_train, y_train)
    test_dataset = StandardDataset(x_test, y_test)
    train_dataset = JaxonDataset(x_train)
    test_dataset = JaxonDataset(x_test)

    return train_dataset, test_dataset

@@ -216,7 +208,9 @@ def get_fashion_mnist():
@jaxonloader_cache(dataset_name="tinyshakespeare")
def get_tiny_shakespeare(
    block_size: int = 8, train_ratio: float = 0.8
) -> tuple[Dataset, Dataset, int, Callable[[str], Array], Callable[[Array], str]]:
) -> tuple[
    JaxonDataset, JaxonDataset, int, Callable[[str], Array], Callable[[Array], str]
]:
"""
Get the tiny shakespeare dataset from Andrej Karpathy's char-rnn repository.
@@ -232,7 +226,6 @@ def get_tiny_shakespeare(
        - encoder: A function that encodes a string into a sequence of integers.
        - decoder: A function that decodes a sequence of integers into a string.

    Example:
        ```python
        from jaxonloader import get_tiny_shakespeare
@@ -241,31 +234,6 @@
        ```
    """

    class MiniShakesPeare(Dataset):
        def __init__(self, data: Array, block_size: int = block_size) -> None:
            self.block_size = block_size
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index: int):
            if index == -1:
                index = len(self.data) - 1
            x = self.data[index : index + self.block_size]
            y = self.data[index + 1 : index + self.block_size + 1]

            if index + self.block_size + 1 > len(self.data):
                diff = index + self.block_size + 1 - len(self.data)

                to_add_on_x = diff - 1
                to_add_on_y = diff

                x = jnp.concatenate((x, self.data[:to_add_on_x]))
                y = jnp.concatenate((y, self.data[:to_add_on_y]))

            return x, y

    def get_text():
        data_path = pathlib.Path(JAXONLOADER_PATH) / "tinyshakespeare/"
        if not os.path.exists(data_path / "input.txt"):
@@ -297,20 +265,28 @@ def decode(latent: Array) -> str:
    data = jnp.array(encode(text))
    n = int(train_ratio * len(data))

    train_data = data[:n]
    test_data = data[n:]
    x_train = data[:n]
    # Drop the trailing partial block so the data reshapes cleanly into
    # rows of length block_size. (Guard against remainder == 0, where
    # x_train[:-0] would drop the whole array.)
    remainder = len(x_train) % block_size
    if remainder:
        x_train = x_train[:-remainder]
    x_train = x_train.reshape(-1, block_size)
    y_train = jnp.roll(x_train, -1)  # targets: inputs shifted by one token
    train_data = jnp.concatenate(arrays=(x_train, y_train), axis=1)
    train_dataset = JaxonDataset(train_data)

    train_dataset = MiniShakesPeare(train_data, block_size=block_size)
    test_dataset = MiniShakesPeare(test_data, block_size=block_size)
    x_test = data[n:]
    remainder = len(x_test) % block_size
    if remainder:
        x_test = x_test[:-remainder]
    x_test = x_test.reshape(-1, block_size)
    y_test = jnp.roll(x_test, -1)
    test_data = jnp.concatenate(arrays=(x_test, y_test), axis=1)
    test_dataset = JaxonDataset(test_data)

    return train_dataset, test_dataset, vocab_size, encoder, decoder


def from_dataframes(
    *dataframes: pl.DataFrame | pd.DataFrame, combine_columns_to_row: bool = False
) -> list[Dataset]:
) -> list[JaxonDataset]:
    """
    Convert a list of polars.DataFrame (or pandas.DataFrame) to a list of Dataset.
    Convert a list of polars.DataFrame (or pandas.DataFrame) to a list of JaxonDataset.

    Args:
        dataframes: A list of polars.DataFrame (or pandas.DataFrame).
@@ -319,16 +295,13 @@ def from_dataframes(
        returned as a tuple. Keyword-only argument.

    Returns:
        A list of Dataset.
        A list of JaxonDataset.
    """
    datasets: list[Dataset] = []
    datasets: list[JaxonDataset] = []
    for df in dataframes:
        dataframe: pl.DataFrame = (
            pl.from_pandas(df) if isinstance(df, pd.DataFrame) else df
        )
        columns = [jnp.array(df[col].to_numpy()) for col in dataframe.columns]
        datasets.append(
            StandardDataset(*columns, combine_columns_to_row=combine_columns_to_row)
        )

        data = jnp.array(dataframe.to_numpy())
        datasets.append(JaxonDataset(data))
    return datasets