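jax_dataloader brings a pytorch-like dataloader API to jax. It
supports:
- 4 dataset types to download and pre-process data: jdl.Dataset, the pytorch Dataset, the huggingface datasets.Dataset, and tf.data.Dataset
- 3 backends to iteratively load batches: "jax", "pytorch", and "tensorflow"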
A minimal jax-dataloader example:

import jax_dataloader as jdl

jdl.manual_seed(1234) # Set the global seed to 1234 for reproducibility
dataloader = jdl.DataLoader(
    dataset, # Can be a jdl.Dataset or pytorch or huggingface or tensorflow dataset
    backend='jax', # Use 'jax' backend for loading data
    batch_size=32, # Batch size
    shuffle=True, # Shuffle the dataloader every iteration or not
    drop_last=False, # Drop the last incomplete batch or not
    generator=jdl.Generator() # Control the randomness of this dataloader
)
batch = next(iter(dataloader)) # Fetch the first batch

The latest jax-dataloader release can be installed directly from PyPI:

pip install jax-dataloader

or install directly from the repository:

pip install git+https://github.com/BirkhoffG/jax-dataloader.git

Note
We keep jax-dataloader's dependencies minimal: installing it only
pulls in jax and plum-dispatch (for backend dispatching).
If you wish to use the pytorch, huggingface datasets, or
tensorflow integrations, we highly recommend
installing those dependencies manually.
You can also run pip install jax-dataloader[all] to install
everything (not recommended).
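
Since only jax and plum-dispatch come with the base install, it can help to check which optional integrations are importable before picking a dataset type or backend. The sketch below uses only the standard library. The helper name `available_integrations` is hypothetical and not part of jax-dataloader, and `torch`, `datasets`, and `tensorflow` are assumed to be the top-level module names of the respective packages.

```python
import importlib.util

# Assumed top-level module names of the optional integrations
# (these are not part of the jax-dataloader API).
OPTIONAL_MODULES = {
    "pytorch": "torch",
    "huggingface datasets": "datasets",
    "tensorflow": "tensorflow",
}


def available_integrations():
    """Map each optional integration to True if its module can be found."""
    return {
        name: importlib.util.find_spec(module) is not None
        for name, module in OPTIONAL_MODULES.items()
    }


if __name__ == "__main__":
    for name, found in available_integrations().items():
        print(f"{name}: {'installed' if found else 'missing'}")
```

`find_spec` only locates the module without importing it, so the check stays cheap even for large packages such as tensorflow.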
jax_dataloader.core.DataLoader
follows an API similar to the pytorch dataloader.
- The dataset should be an instance of a subclass of jax_dataloader.core.Dataset, torch.utils.data.Dataset, (the huggingface) datasets.Dataset, or tf.data.Dataset.
- The backend should be one of "jax", "pytorch", or "tensorflow". This argument specifies which backend dataloader is used to load batches.
Note that not every dataset is compatible with every backend. See the compatibility table below:
| | jdl.Dataset | torch_data.Dataset | tf.data.Dataset | datasets.Dataset |
|---|---|---|---|---|
| "jax" | ✅ | ❌ | ❌ | ✅ |
| "pytorch" | ✅ | ✅ | ❌ | ✅ |
| "tensorflow" | ✅ | ❌ | ✅ | ✅ |
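
The compatibility table can also be written down as a small lookup, which is handy for failing early with a clear message. This is only an illustrative sketch: the names `COMPATIBILITY` and `is_compatible` are hypothetical and are not provided by jax-dataloader, and the entries simply transcribe the table above.

```python
# Transcription of the compatibility table: backend -> dataset type -> supported?
COMPATIBILITY = {
    "jax": {
        "jdl.Dataset": True,
        "torch_data.Dataset": False,
        "tf.data.Dataset": False,
        "datasets.Dataset": True,
    },
    "pytorch": {
        "jdl.Dataset": True,
        "torch_data.Dataset": True,
        "tf.data.Dataset": False,
        "datasets.Dataset": True,
    },
    "tensorflow": {
        "jdl.Dataset": True,
        "torch_data.Dataset": False,
        "tf.data.Dataset": True,
        "datasets.Dataset": True,
    },
}


def is_compatible(backend, dataset_kind):
    """Return True if the table marks this backend/dataset pair as supported."""
    try:
        return COMPATIBILITY[backend][dataset_kind]
    except KeyError:
        raise ValueError(
            f"unknown backend or dataset type: {backend!r}, {dataset_kind!r}"
        ) from None
```

For example, `is_compatible("jax", "torch_data.Dataset")` is `False`, matching the ❌ in the first row.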
Using ArrayDataset
The jax_dataloader.core.ArrayDataset is an easy way to wrap multiple
jax.numpy arrays into one Dataset. For example, we can create an
ArrayDataset
as follows:
import jax.numpy as jnp

# Create features `X` and labels `y`
X = jnp.arange(100).reshape(10, 10)
y = jnp.arange(10)
# Create an `ArrayDataset`
arr_ds = jdl.ArrayDataset(X, y)

This arr_ds can be loaded by every backend.
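
Conceptually, wrapping several arrays in one dataset means that indexing the dataset returns the matching entry from each array. The pure-Python class below is a hypothetical stand-in that illustrates this idea, assuming all wrapped arrays share the same first dimension. It is not the library's ArrayDataset implementation, and the name `PairedArrays` is made up for this sketch.

```python
class PairedArrays:
    """Hypothetical stand-in: index several equal-length sequences together."""

    def __init__(self, *arrays):
        if not arrays:
            raise ValueError("at least one array is required")
        lengths = {len(a) for a in arrays}
        if len(lengths) != 1:
            raise ValueError(
                f"arrays must share the same length, got {sorted(lengths)}"
            )
        self.arrays = arrays

    def __len__(self):
        return len(self.arrays[0])

    def __getitem__(self, index):
        # One entry per wrapped array, e.g. (features, label).
        return tuple(a[index] for a in self.arrays)


# Same shapes as the example above, built from plain lists instead of jnp arrays.
X = [list(range(row * 10, row * 10 + 10)) for row in range(10)]
y = list(range(10))
ds = PairedArrays(X, y)
features, label = ds[3]  # row 3 of X together with y[3]
```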
# Create a `DataLoader` from the `ArrayDataset` via jax backend
dataloader = jdl.DataLoader(arr_ds, 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(arr_ds, 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(arr_ds, 'tensorflow', batch_size=5, shuffle=True)

The huggingface datasets library is a
modern library for downloading, pre-processing, and sharing datasets.
jax_dataloader supports passing huggingface datasets directly.

from datasets import load_dataset

For example, we load the "squad" dataset from datasets:

hf_ds = load_dataset("squad")

Then, we can use jax_dataloader to load batches of hf_ds.

# Create a `DataLoader` from the `datasets.Dataset` via jax backend
dataloader = jdl.DataLoader(hf_ds['train'], 'jax', batch_size=5, shuffle=True)
# Or we can use the pytorch backend
dataloader = jdl.DataLoader(hf_ds['train'], 'pytorch', batch_size=5, shuffle=True)
# Or we can use the tensorflow backend
dataloader = jdl.DataLoader(hf_ds['train'], 'tensorflow', batch_size=5, shuffle=True)

The pytorch Dataset and its
ecosystem (e.g.,
torchvision,
torchtext,
torchaudio) support many
built-in datasets. jax_dataloader supports passing the
pytorch Dataset directly.
Note
Unfortunately, the pytorch
Dataset only works with
backend='pytorch'. See the example below.
from torchvision.datasets import MNIST
import numpy as np

We load the MNIST dataset from torchvision. The transform argument
converts each image to a numpy.array.

pt_ds = MNIST('/tmp/mnist/', download=True, transform=lambda x: np.array(x, dtype=float), train=False)

This pt_ds can only be loaded via "pytorch" dataloaders.

dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=5, shuffle=True)

jax_dataloader supports passing tensorflow datasets directly.

import tensorflow_datasets as tfds
import tensorflow as tf

For instance, we can load the MNIST dataset from tensorflow_datasets:

tf_ds = tfds.load('mnist', split='test', as_supervised=True)

and use jax_dataloader to iterate over the dataset:

dataloader = jdl.DataLoader(tf_ds, 'tensorflow', batch_size=5, shuffle=True)
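
The batch_size, shuffle, and drop_last arguments used throughout these examples follow the usual dataloader semantics. The pure-Python generator below is a minimal sketch of those semantics, assuming a seeded `random.Random` as a stand-in for jdl.manual_seed or jdl.Generator. The function name `iter_batches` is hypothetical, and this is not how any of the backends is actually implemented.

```python
import random


def iter_batches(n_samples, batch_size, shuffle=False, drop_last=False, seed=None):
    """Yield lists of sample indices, one list per batch."""
    indices = list(range(n_samples))
    if shuffle:
        # A seeded generator makes the shuffled order reproducible.
        random.Random(seed).shuffle(indices)
    for start in range(0, n_samples, batch_size):
        batch = indices[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # skip the final, incomplete batch
        yield batch


# 10 samples with batch_size=4: the last batch holds only 2 samples.
sizes_keep = [len(b) for b in iter_batches(10, 4)]                  # [4, 4, 2]
sizes_drop = [len(b) for b in iter_batches(10, 4, drop_last=True)]  # [4, 4]
```

With shuffle=True and a fixed seed, two passes over `iter_batches` visit the samples in the same order, which is the kind of reproducibility a global seed is meant to provide.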