diff --git a/docs/examples/sparse_finch.ipynb b/docs/examples/sparse_finch.ipynb
new file mode 100644
index 00000000..cbf942b4
--- /dev/null
+++ b/docs/examples/sparse_finch.ipynb
@@ -0,0 +1,543 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Finch backend for `sparse`\n",
+ "\n",
+ "\n",
+ "\n",
+ " to download and run."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#!pip install 'sparse[finch]==0.16.0a9' scipy\n",
+ "#!export SPARSE_BACKEND=Finch\n",
+ "\n",
+ "# let's make sure we're using Finch backend\n",
+ "import os\n",
+ "\n",
+ "os.environ[\"SPARSE_BACKEND\"] = \"Finch\"\n",
+ "CI_MODE = bool(int(os.getenv(\"CI_MODE\", default=\"0\")))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import importlib\n",
+ "import time\n",
+ "\n",
+ "import sparse\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import networkx as nx\n",
+ "\n",
+ "import numpy as np\n",
+ "import scipy.sparse as sps\n",
+ "import scipy.sparse.linalg as splin"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tns = sparse.asarray(np.zeros((10, 10))) # offers a no-copy constructor for NumPy and scipy.sparse inputs\n",
+ "\n",
+ "s1 = sparse.random((100, 10), density=0.01) # creates random COO tensor\n",
+ "s2 = sparse.random((100, 100, 10), density=0.01)\n",
+ "s2 = sparse.asarray(s2, format=\"csf\") # can be used to rewrite tensor to a new format\n",
+ "\n",
+ "result = sparse.tensordot(s1, s2, axes=([0, 1], [0, 2]))\n",
+ "\n",
+ "total = sparse.sum(result * result)\n",
+ "print(total)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Example: least squares - closed form"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y = sparse.random((100, 1), density=0.08)\n",
+ "X = sparse.random((100, 5), density=0.08)\n",
+ "X = sparse.asarray(X, format=\"csc\")\n",
+ "X_lazy = sparse.lazy(X)\n",
+ "\n",
+ "X_X = sparse.compute(sparse.permute_dims(X_lazy, (1, 0)) @ X_lazy)\n",
+ "\n",
+ "X_X = sparse.asarray(X_X, format=\"csc\") # move back from dense to CSC format\n",
+ "\n",
+ "inverted = splin.inv(X_X) # dispatching to scipy.sparse.sparray\n",
+ "\n",
+ "b_hat = (inverted @ sparse.permute_dims(X, (1, 0))) @ y\n",
+ "\n",
+ "print(b_hat.todense())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Benchmark plots"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ITERS = 1\n",
+ "rng = np.random.default_rng(0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.style.use(\"seaborn-v0_8\")\n",
+ "plt.rcParams[\"figure.dpi\"] = 400\n",
+ "plt.rcParams[\"figure.figsize\"] = [8, 4]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def benchmark(func, info, args) -> float:\n",
+ "    \"\"\"Return the mean seconds per `func(*args)` call over ITERS runs; `info` is a label that is not used here.\"\"\"\n",
+ "    start = time.perf_counter()  # monotonic, high-resolution clock intended for measuring intervals\n",
+ "    for _ in range(ITERS):\n",
+ "        func(*args)\n",
+ "    return (time.perf_counter() - start) / ITERS"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## MTTKRP"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"MTTKRP Example:\\n\")\n",
+ "\n",
+ "os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n",
+ "importlib.reload(sparse)\n",
+ "\n",
+ "configs = [\n",
+ " {\"I_\": 100, \"J_\": 25, \"K_\": 100, \"L_\": 10, \"DENSITY\": 0.001},\n",
+ " {\"I_\": 100, \"J_\": 25, \"K_\": 100, \"L_\": 100, \"DENSITY\": 0.001},\n",
+ " {\"I_\": 1000, \"J_\": 25, \"K_\": 100, \"L_\": 100, \"DENSITY\": 0.001},\n",
+ " {\"I_\": 1000, \"J_\": 25, \"K_\": 1000, \"L_\": 100, \"DENSITY\": 0.001},\n",
+ " {\"I_\": 1000, \"J_\": 25, \"K_\": 1000, \"L_\": 1000, \"DENSITY\": 0.001},\n",
+ "]\n",
+ "nonzeros = [100_000, 1_000_000, 10_000_000, 100_000_000, 1_000_000_000]\n",
+ "\n",
+ "if CI_MODE:\n",
+ " configs = configs[:1]\n",
+ " nonzeros = nonzeros[:1]\n",
+ "\n",
+ "finch_times = []\n",
+ "numba_times = []\n",
+ "finch_galley_times = []\n",
+ "\n",
+ "for config in configs:\n",
+ " B_shape = (config[\"I_\"], config[\"K_\"], config[\"L_\"])\n",
+ " B_sps = sparse.random(B_shape, density=config[\"DENSITY\"], random_state=rng)\n",
+ " D_sps = rng.random((config[\"L_\"], config[\"J_\"]))\n",
+ " C_sps = rng.random((config[\"K_\"], config[\"J_\"]))\n",
+ "\n",
+ " # ======= Finch =======\n",
+ " os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
+ " importlib.reload(sparse)\n",
+ "\n",
+ " B = sparse.asarray(B_sps.todense(), format=\"csf\")\n",
+ " D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
+ " C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
+ "\n",
+ " @sparse.compiled(opt=sparse.DefaultScheduler())\n",
+ " def mttkrp_finch(B, D, C):\n",
+ " return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
+ "\n",
+ " # Compile\n",
+ " result_finch = mttkrp_finch(B, D, C)\n",
+ " # Benchmark\n",
+ " time_finch = benchmark(mttkrp_finch, info=\"Finch\", args=[B, D, C])\n",
+ "\n",
+ " # ======= Finch Galley =======\n",
+ " os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
+ " importlib.reload(sparse)\n",
+ "\n",
+ " B = sparse.asarray(B_sps.todense(), format=\"csf\")\n",
+ " D = sparse.asarray(np.array(D_sps, order=\"F\"))\n",
+ " C = sparse.asarray(np.array(C_sps, order=\"F\"))\n",
+ "\n",
+ " @sparse.compiled(opt=sparse.GalleyScheduler(), tag=sum(B_shape))\n",
+ " def mttkrp_finch_galley(B, D, C):\n",
+ " return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
+ "\n",
+ " # Compile\n",
+ " result_finch_galley = mttkrp_finch_galley(B, D, C)\n",
+ " # Benchmark\n",
+ " time_finch_galley = benchmark(mttkrp_finch_galley, info=\"Finch Galley\", args=[B, D, C])\n",
+ "\n",
+ " # ======= Numba =======\n",
+ " os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n",
+ " importlib.reload(sparse)\n",
+ "\n",
+ " B = sparse.asarray(B_sps, format=\"gcxs\")\n",
+ " D = D_sps\n",
+ " C = C_sps\n",
+ "\n",
+ " def mttkrp_numba(B, D, C):\n",
+ " return sparse.sum(B[:, :, :, None] * D[None, None, :, :] * C[None, :, None, :], axis=(1, 2))\n",
+ "\n",
+ " # Compile\n",
+ " result_numba = mttkrp_numba(B, D, C)\n",
+ " # Benchmark\n",
+ " time_numba = benchmark(mttkrp_numba, info=\"Numba\", args=[B, D, C])\n",
+ "\n",
+ " np.testing.assert_allclose(result_finch.todense(), result_numba.todense())\n",
+ "\n",
+ " finch_times.append(time_finch)\n",
+ " numba_times.append(time_numba)\n",
+ " finch_galley_times.append(time_finch_galley)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, ax = plt.subplots(nrows=1, ncols=1)\n",
+ "\n",
+ "ax.plot(nonzeros, finch_times, \"o-\", label=\"Finch\")\n",
+ "ax.plot(nonzeros, numba_times, \"o-\", label=\"Numba\")\n",
+ "ax.plot(nonzeros, finch_galley_times, \"o-\", label=\"Finch - Galley\")\n",
+ "ax.grid(True)\n",
+ "ax.set_xlabel(\"no. of elements\")\n",
+ "ax.set_ylabel(\"time (sec)\")\n",
+ "ax.set_title(\"MTTKRP\")\n",
+ "ax.set_xscale(\"log\")\n",
+ "ax.set_yscale(\"log\")\n",
+ "ax.legend(loc=\"best\", numpoints=1)\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## SDDMM"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"SDDMM Example:\\n\")\n",
+ "\n",
+ "configs = [\n",
+ " {\"LEN\": 5000, \"DENSITY\": 0.00001},\n",
+ " {\"LEN\": 10000, \"DENSITY\": 0.00001},\n",
+ " {\"LEN\": 15000, \"DENSITY\": 0.00001},\n",
+ " {\"LEN\": 20000, \"DENSITY\": 0.00001},\n",
+ " {\"LEN\": 25000, \"DENSITY\": 0.00001},\n",
+ " {\"LEN\": 30000, \"DENSITY\": 0.00001},\n",
+ "]\n",
+ "size_n = [5000, 10000, 15000, 20000, 25000, 30000]\n",
+ "\n",
+ "if CI_MODE:\n",
+ " configs = configs[:1]\n",
+ " size_n = size_n[:1]\n",
+ "\n",
+ "finch_times = []\n",
+ "numba_times = []\n",
+ "scipy_times = []\n",
+ "finch_galley_times = []\n",
+ "\n",
+ "for config in configs:\n",
+ " LEN = config[\"LEN\"]\n",
+ " DENSITY = config[\"DENSITY\"]\n",
+ "\n",
+ " a_sps = rng.random((LEN, LEN))\n",
+ " b_sps = rng.random((LEN, LEN))\n",
+ " s_sps = sps.random(LEN, LEN, format=\"coo\", density=DENSITY, random_state=rng)\n",
+ " s_sps.sum_duplicates()\n",
+ "\n",
+ " # ======= Finch =======\n",
+ " print(\"finch\")\n",
+ " os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
+ " importlib.reload(sparse)\n",
+ "\n",
+ " s = sparse.asarray(s_sps)\n",
+ " a = sparse.asarray(a_sps)\n",
+ " b = sparse.asarray(b_sps)\n",
+ "\n",
+ " @sparse.compiled(opt=sparse.DefaultScheduler())\n",
+ " def sddmm_finch(s, a, b):\n",
+ " return s * (a @ b)\n",
+ "\n",
+ " # Compile\n",
+ " result_finch = sddmm_finch(s, a, b)\n",
+ " # Benchmark\n",
+ " time_finch = benchmark(sddmm_finch, info=\"Finch\", args=[s, a, b])\n",
+ "\n",
+ " # ======= Finch Galley =======\n",
+ " print(\"finch galley\")\n",
+ " os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
+ " importlib.reload(sparse)\n",
+ "\n",
+ " s = sparse.asarray(s_sps)\n",
+ " a = sparse.asarray(a_sps)\n",
+ " b = sparse.asarray(b_sps)\n",
+ "\n",
+ " @sparse.compiled(opt=sparse.GalleyScheduler(), tag=LEN)\n",
+ " def sddmm_finch_galley(s, a, b):\n",
+ " return s * (a @ b)\n",
+ "\n",
+ " # Compile\n",
+ " result_finch_galley = sddmm_finch_galley(s, a, b)\n",
+ " # Benchmark\n",
+ " time_finch_galley = benchmark(sddmm_finch_galley, info=\"Finch Galley\", args=[s, a, b])\n",
+ "\n",
+ " # ======= Numba =======\n",
+ " print(\"numba\")\n",
+ " os.environ[sparse._ENV_VAR_NAME] = \"Numba\"\n",
+ " importlib.reload(sparse)\n",
+ "\n",
+ " s = sparse.asarray(s_sps)\n",
+ " a = a_sps\n",
+ " b = b_sps\n",
+ "\n",
+ " def sddmm_numba(s, a, b):\n",
+ " return s * (a @ b)\n",
+ "\n",
+ " # Compile\n",
+ " result_numba = sddmm_numba(s, a, b)\n",
+ " # Benchmark\n",
+ " time_numba = benchmark(sddmm_numba, info=\"Numba\", args=[s, a, b])\n",
+ "\n",
+ " # ======= SciPy =======\n",
+ " print(\"scipy\")\n",
+ "\n",
+ " def sddmm_scipy(s, a, b):\n",
+ " return s.multiply(a @ b)\n",
+ "\n",
+ " s = s_sps.asformat(\"csr\")\n",
+ " a = a_sps\n",
+ " b = b_sps\n",
+ "\n",
+ " result_scipy = sddmm_scipy(s, a, b)\n",
+ " # Benchmark\n",
+ " time_scipy = benchmark(sddmm_scipy, info=\"SciPy\", args=[s, a, b])\n",
+ "\n",
+ " finch_times.append(time_finch)\n",
+ " numba_times.append(time_numba)\n",
+ " scipy_times.append(time_scipy)\n",
+ " finch_galley_times.append(time_finch_galley)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, ax = plt.subplots(nrows=1, ncols=1)\n",
+ "\n",
+ "ax.plot(size_n, finch_times, \"o-\", label=\"Finch\")\n",
+ "ax.plot(size_n, numba_times, \"o-\", label=\"Numba\")\n",
+ "ax.plot(size_n, scipy_times, \"o-\", label=\"SciPy\")\n",
+ "ax.plot(size_n, finch_galley_times, \"o-\", label=\"Finch Galley\")\n",
+ "\n",
+ "ax.grid(True)\n",
+ "ax.set_xlabel(\"size N\")\n",
+ "ax.set_ylabel(\"time (sec)\")\n",
+ "ax.set_title(\"SDDMM\")\n",
+ "ax.legend(loc=\"best\", numpoints=1)\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Counting Triangles"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print(\"Counting Triangles Example:\\n\")\n",
+ "# Times triangle counting on G(n, p) random graphs with Finch (default and Galley schedulers), SciPy and NetworkX.\n",
+ "configs = [\n",
+ "    {\"LEN\": 10000, \"DENSITY\": 0.001},\n",
+ "    {\"LEN\": 15000, \"DENSITY\": 0.001},\n",
+ "    {\"LEN\": 20000, \"DENSITY\": 0.001},\n",
+ "    {\"LEN\": 25000, \"DENSITY\": 0.001},\n",
+ "    {\"LEN\": 30000, \"DENSITY\": 0.001},\n",
+ "    {\"LEN\": 35000, \"DENSITY\": 0.001},\n",
+ "    {\"LEN\": 40000, \"DENSITY\": 0.001},\n",
+ "    {\"LEN\": 45000, \"DENSITY\": 0.001},\n",
+ "    {\"LEN\": 50000, \"DENSITY\": 0.001},\n",
+ "]\n",
+ "size_n = [10000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000]\n",
+ "\n",
+ "if CI_MODE:\n",
+ "    configs = configs[:1]\n",
+ "    size_n = size_n[:1]\n",
+ "\n",
+ "finch_times = []\n",
+ "finch_galley_times = []\n",
+ "networkx_times = []\n",
+ "scipy_times = []\n",
+ "\n",
+ "for config in configs:\n",
+ "    LEN = config[\"LEN\"]\n",
+ "    DENSITY = config[\"DENSITY\"]\n",
+ "    # fast_gnp_random_graph samples the same G(n, p) model without the O(n^2) pair loop; setup is not timed\n",
+ "    G = nx.fast_gnp_random_graph(n=LEN, p=DENSITY)\n",
+ "    a_sps = nx.to_scipy_sparse_array(G)\n",
+ "\n",
+ "    # ======= Finch =======\n",
+ "    print(\"finch\")\n",
+ "    os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
+ "    importlib.reload(sparse)\n",
+ "\n",
+ "    a = sparse.asarray(a_sps)\n",
+ "\n",
+ "    @sparse.compiled(opt=sparse.DefaultScheduler())\n",
+ "    def ct_finch(a):\n",
+ "        return sparse.sum(a @ a * a) / sparse.asarray(6)\n",
+ "\n",
+ "    # Compile: untimed first call, made before the benchmark loop runs\n",
+ "    result_finch = ct_finch(a)\n",
+ "    # Benchmark\n",
+ "    time_finch = benchmark(ct_finch, info=\"Finch\", args=[a])\n",
+ "\n",
+ "    # ======= Finch Galley =======\n",
+ "    print(\"finch galley\")\n",
+ "    os.environ[sparse._ENV_VAR_NAME] = \"Finch\"\n",
+ "    importlib.reload(sparse)\n",
+ "\n",
+ "    a = sparse.asarray(a_sps)\n",
+ "\n",
+ "    @sparse.compiled(opt=sparse.GalleyScheduler(), tag=LEN)\n",
+ "    def ct_finch_galley(a):\n",
+ "        return sparse.sum(a @ a * a) / sparse.asarray(6)\n",
+ "\n",
+ "    # Compile: untimed first call, made before the benchmark loop runs\n",
+ "    result_finch_galley = ct_finch_galley(a)\n",
+ "    # Benchmark\n",
+ "    time_finch_galley = benchmark(ct_finch_galley, info=\"Finch Galley\", args=[a])\n",
+ "\n",
+ "    # ======= SciPy =======\n",
+ "    print(\"scipy\")\n",
+ "\n",
+ "    def ct_scipy(a):\n",
+ "        return (a @ a * a).sum() / 6\n",
+ "\n",
+ "    a = a_sps\n",
+ "\n",
+ "    # Benchmark\n",
+ "    time_scipy = benchmark(ct_scipy, info=\"SciPy\", args=[a])\n",
+ "\n",
+ "    # ======= NetworkX =======\n",
+ "    print(\"networkx\")\n",
+ "\n",
+ "    def ct_networkx(a):\n",
+ "        return sum(nx.triangles(a).values()) / 3\n",
+ "\n",
+ "    a = G\n",
+ "    # Benchmark\n",
+ "    time_networkx = benchmark(ct_networkx, info=\"NetworkX\", args=[a])\n",
+ "\n",
+ "    finch_times.append(time_finch)\n",
+ "    finch_galley_times.append(time_finch_galley)\n",
+ "    networkx_times.append(time_networkx)\n",
+ "    scipy_times.append(time_scipy)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig, ax = plt.subplots(nrows=1, ncols=1)\n",
+ "\n",
+ "ax.plot(size_n, finch_times, \"o-\", label=\"Finch\")\n",
+ "ax.plot(size_n, networkx_times, \"o-\", label=\"NetworkX\")\n",
+ "ax.plot(size_n, scipy_times, \"o-\", label=\"SciPy\")\n",
+ "ax.plot(size_n, finch_galley_times, \"o-\", label=\"Finch Galley\")\n",
+ "\n",
+ "ax.grid(True)\n",
+ "ax.set_xlabel(\"size N\")\n",
+ "ax.set_ylabel(\"time (sec)\")\n",
+ "ax.set_title(\"Counting Triangles\")\n",
+ "ax.legend(loc=\"best\", numpoints=1)\n",
+ "\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "sparse-dev",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/notebooks.md b/docs/notebooks.md
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/sparse_finch.ipynb b/examples/sparse_finch.ipynb
index 785860df..cbf942b4 100644
--- a/examples/sparse_finch.ipynb
+++ b/examples/sparse_finch.ipynb
@@ -6,7 +6,7 @@
"source": [
"## Finch backend for `sparse`\n",
"\n",
- "\n",
+ "\n",
 "\n",
" to download and run."
]
diff --git a/mkdocs.yml b/mkdocs.yml
index 0e18f917..0d52e6de 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -52,6 +52,7 @@ plugins:
- gen-files:
scripts:
- scripts/gen_ref_pages.py
+ - scripts/copy_notebooks.py
- literate-nav
- mkdocstrings:
handlers:
@@ -73,8 +74,11 @@ plugins:
- mkdocs-jupyter:
include_source: true
- execute: true
+ execute: false
ignore: ["__init__.py", "utils.py", "gen_logo.py"]
+ include:
+ - examples/*.ipynb
+ - docs/examples/*.ipynb
nav:
- Home:
@@ -101,3 +105,6 @@ nav:
- completed-tasks.md
- changelog.md
- conduct.md
+ - Notebooks:
+ - notebooks.md
+ - examples/sparse_finch.ipynb
diff --git a/scripts/copy_notebooks.py b/scripts/copy_notebooks.py
new file mode 100644
index 00000000..837e56a2
--- /dev/null
+++ b/scripts/copy_notebooks.py
@@ -0,0 +1,10 @@
+import shutil
+from pathlib import Path
+
+source_dir = Path("examples")
+dest_dir = Path("docs/examples")
+
+dest_dir.mkdir(parents=True, exist_ok=True)
+
+for notebook in source_dir.glob("*.ipynb"):
+ shutil.copy2(notebook, dest_dir / notebook.name)