Pallas is JAX's kernel language for writing custom operations that run on TPUs, playing the role that Triton plays for GPUs. Kernel languages provide low-level access to the hardware, enabling optimizations that are out of the compiler's reach.
This repo contains progressive puzzle notebooks that build from Pallas basics toward real open-source kernels. All puzzles run on free Google Colab CPU instances via `interpret=True` — no TPU needed.
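To show what interpreter mode looks like in practice, here is a minimal sketch of a Pallas kernel (an elementwise add, not one of the puzzle kernels) that runs on CPU with `interpret=True`:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs are views into kernel memory: read both inputs, write the sum.
    o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)

out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # emulate the kernel on CPU; no TPU required
)(x, y)
```

The same `pallas_call` works unchanged on a real TPU once `interpret=True` is dropped.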
Blog post: vorushin.github.io/blog/pallas-puzzles
SplashAttention (SParse fLASH attention) is an efficient block-sparse implementation of attention on TPUs.
- `basics`: how to write Pallas kernels, up to batched matmuls.
- `splash_attention`: from vanilla softmax to the block-sparse implementation.
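As a taste of the starting point of that notebook, here is a sketch of the numerically stable softmax (the max-subtraction trick that flash-style attention kernels later maintain incrementally, block by block):

```python
import jax.numpy as jnp

def softmax(x):
    # Subtract the row max before exponentiating so exp() never overflows;
    # the result is mathematically identical to the naive formula.
    m = jnp.max(x, axis=-1, keepdims=True)
    e = jnp.exp(x - m)
    return e / jnp.sum(e, axis=-1, keepdims=True)

probs = softmax(jnp.array([[1.0, 2.0, 3.0]]))
```

The puzzles rebuild this incrementally so that attention over long sequences never materializes the full score matrix.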
Grouped matrix multiplications are the core building block of modern mixture-of-experts (MoE) models.
- `basics`: same as above.
- `grouped_matmul`: how to split tokens into blocks and multiply them efficiently with expert weights.
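The idea can be sketched with a reference (non-kernel) version: tokens arrive sorted by expert, `group_sizes` says how many go to each expert, and each contiguous block is multiplied by that expert's weights. All shapes and names here are illustrative, not the notebook's API:

```python
import jax.numpy as jnp

# Hypothetical setup: 6 tokens of dim 4, routed to 2 experts with group
# sizes [4, 2]; each expert owns its own (4, 8) weight matrix.
tokens = jnp.arange(24, dtype=jnp.float32).reshape(6, 4)
group_sizes = [4, 2]
weights = jnp.stack([jnp.eye(4, 8), 2.0 * jnp.eye(4, 8)])

def grouped_matmul(tokens, group_sizes, weights):
    # Multiply each contiguous token block by its expert's weights,
    # then concatenate the per-expert outputs back in order.
    outs, start = [], 0
    for e, size in enumerate(group_sizes):
        outs.append(tokens[start:start + size] @ weights[e])
        start += size
    return jnp.concatenate(outs, axis=0)

out = grouped_matmul(tokens, group_sizes, weights)
```

A Pallas kernel replaces the Python loop with a grid over tiles, skipping work at block boundaries between experts.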
Prerequisites: solid JAX/NumPy. No prior Pallas experience required.
The notebooks were created with Claude Code. The project's CLAUDE.md contains guidelines for creating new notebooks with progressive puzzles; it could be a good starting point for building interactive study materials tailored to your needs.
Click a Colab badge above, or run locally:

```shell
pip install jax jaxtyping
jupyter notebook basics.ipynb
```