Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions examples/onnx/pla_sigmoid/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# PLA Sigmoid: Replace Lookup with Piecewise Linear (Copy/Arithmetic Constraints)

This example shows how to **avoid Sigmoid lookup tables** in ezkl by using a **piecewise linear approximation (PLA)** that uses only multiplications, additions, and comparisons. The circuit then relies on **copy and arithmetic constraints** instead of lookups, which can reduce proof cost and simplify settings.

## Why this helps

- **Sigmoid** is often implemented in ezkl via a **lookup table**, which can be expensive in circuit size and proving time.
- Replacing it with a **PLA** (segment-wise linear: `y = slope_i * x + intercept_i` per segment) gives:
- **No lookup**: only mul/add/compare ops → no `lookup_range` or large lookup tables.
- **Lighter settings**: e.g. `num_inner_cols=1`, `lookup_range=(0,1)`, `bounded_log_lookup=False`.
- **Faster proving** in practice when the rest of the circuit is comparable.

## Accuracy

- The PLA is fitted so that **max relative error vs true sigmoid ≤ 0.5%** (i.e. **≥ 99.5% accuracy**) over the chosen input range (default `[-2, 2]`).
- You can recalibrate with `python pwl_sigmoid.py` (writes `pwl_params.npz`).

## Files

| File | Description |
|------|-------------|
| `pwl_sigmoid.py` | PLA fitting and `PWLSigmoid` PyTorch module (mul/add/compare only). |
| `gen.py` | Builds the model (conv + PLA Sigmoid), exports ONNX and `input.json`. |
| `network.onnx` | Exported model (no Sigmoid op; PLA is expanded into linear ops). |
| `input.json` | Sample inputs and shapes for ezkl. |
| `pwl_params.npz` | Pre-fitted PLA parameters (optional; `gen.py` can regenerate). |

## How to run

1. **Export ONNX and input (if not already present):**
```bash
cd examples/onnx/pla_sigmoid
pip install torch numpy
python gen.py
```
This produces `network.onnx` and `input.json` (and optionally `pwl_params.npz`).

2. **Use with ezkl** (no lookup needed; lightweight settings):
```python
import ezkl
py_run_args = ezkl.PyRunArgs()
py_run_args.input_visibility = "public"
py_run_args.output_visibility = "public"
py_run_args.param_visibility = "fixed"
py_run_args.num_inner_cols = 1
py_run_args.lookup_range = (0, 1)
py_run_args.bounded_log_lookup = False
py_run_args.logrows = 16 # or as required by your circuit size
ezkl.gen_settings("network.onnx", "settings.json", py_run_args=py_run_args)
ezkl.compile_circuit("network.onnx", "network.compiled", "settings.json")
# Then: SRS, setup, gen_witness, prove as in the main ezkl docs.
```

## Author

[changshenhan](https://github.com/changshenhan) — PLA-on-Sigmoid example for ezkl.
85 changes: 85 additions & 0 deletions examples/onnx/pla_sigmoid/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Export ONNX model that uses Piecewise Linear (PLA) Sigmoid instead of Lookup.

This example shows how to avoid expensive Sigmoid lookup in ezkl by replacing it
with a PLA that uses only mul/add/compare — expressible with copy/arithmetic constraints.
Run from this directory: python gen.py
"""
import os
import numpy as np
import torch
import torch.onnx
import torch.nn as nn
import torch.nn.init as init
import json

from pwl_sigmoid import PWLSigmoid, calibrate_and_save, fit_pwl_sigmoid


def _load_pwl_sigmoid():
path = os.path.join(os.path.dirname(__file__), "pwl_params.npz")
if os.path.exists(path):
d = np.load(path)
return PWLSigmoid(d["breakpoints"], d["slopes"], d["intercepts"])
breakpoints, slopes, intercepts = fit_pwl_sigmoid(n_segments=16, target_max_rel_error=0.005)
np.savez(path, breakpoints=breakpoints, slopes=slopes, intercepts=intercepts)
return PWLSigmoid(breakpoints, slopes, intercepts)


class Circuit(nn.Module):
"""Small circuit: conv + PLA Sigmoid (no lookup)."""

def __init__(self):
super().__init__()
self.relu = nn.ReLU()
self.act = _load_pwl_sigmoid()
self.conv = nn.Conv2d(3, 3, (2, 2), 1, 2)
init.orthogonal_(self.conv.weight)

def forward(self, x, y, z):
x = self.act(self.conv(y @ x**2 + (x) - (self.relu(z)))) + 2
return (x, self.relu(z) / 3)


def main():
if not os.path.exists(os.path.join(os.path.dirname(__file__), "pwl_params.npz")):
calibrate_and_save(target_accuracy=0.995, out_path=os.path.join(os.path.dirname(__file__), "pwl_params.npz"))
torch_model = Circuit()
shape = [3, 2, 2]
x = 0.1 * torch.rand(1, *shape, requires_grad=True)
y = 0.1 * torch.rand(1, *shape, requires_grad=True)
z = 0.1 * torch.rand(1, *shape, requires_grad=True)
torch_out = torch_model(x, y, z)

out_dir = os.path.dirname(__file__)
onnx_path = os.path.join(out_dir, "network.onnx")
input_path = os.path.join(out_dir, "input.json")

torch.onnx.export(
torch_model,
(x, y, z),
onnx_path,
export_params=True,
opset_version=10,
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

data = dict(
input_shapes=[shape, shape, shape],
input_data=[
x.detach().numpy().reshape(-1).tolist(),
y.detach().numpy().reshape(-1).tolist(),
z.detach().numpy().reshape(-1).tolist(),
],
output_data=[o.detach().numpy().reshape(-1).tolist() for o in torch_out],
)
with open(input_path, "w") as f:
json.dump(data, f)
print("Exported network.onnx and input.json")


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions examples/onnx/pla_sigmoid/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"input_shapes": [[3, 2, 2], [3, 2, 2], [3, 2, 2]], "input_data": [[0.006402426864951849, 0.0546477846801281, 0.03791030868887901, 0.057746853679418564, 0.03026706539094448, 0.05770156532526016, 0.07304316759109497, 0.07641308754682541, 0.06932983547449112, 0.0041304826736450195, 0.009861445985734463, 0.07692736387252808], [0.0415552519261837, 0.013548016548156738, 0.046301256865262985, 0.09327216446399689, 0.045912325382232666, 0.06049604341387749, 0.07569151371717453, 0.07194950431585312, 0.013698828406631947, 0.09671473503112793, 0.08251019567251205, 0.04880925640463829], [0.000953847193159163, 0.06371661275625229, 0.09247473627328873, 0.04223586246371269, 0.01274173241108656, 0.09394653886556625, 0.035997193306684494, 0.011770474724471569, 0.038802023977041245, 0.0823185071349144, 0.07291588932275772, 0.0731731653213501]], "output_data": [[2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.530578136444092, 2.5382046699523926, 2.5266213417053223, 2.531947135925293, 2.531947135925293, 2.5237579345703125, 2.540153980255127, 2.5347211360931396, 2.531947135925293, 2.531947135925293, 2.543057680130005, 2.5285441875457764, 2.5347025394439697, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.527876853942871, 2.5258495807647705, 2.5168540477752686, 2.525843858718872, 2.525843858718872, 2.529510021209717, 2.527862071990967, 2.529017925262451, 2.525843858718872, 2.525843858718872, 2.521322250366211, 2.5252134799957275, 2.5183987617492676, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.5534555912017822, 2.5507540702819824, 2.5608303546905518, 2.553361654281616, 2.553361654281616, 2.561260461807251, 2.551497459411621, 2.547823667526245, 2.553361654281616, 2.553361654281616, 2.5519490242004395, 2.565648078918457, 2.5595531463623047, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616], [0.0003179490740876645, 0.021238870918750763, 0.03082491271197796, 0.014078620821237564, 0.004247243981808424, 0.03131551295518875, 0.01199906412512064, 0.003923491574823856, 0.012934007681906223, 0.027439502999186516, 0.024305297061800957, 0.0243910551071167]]}
Binary file added examples/onnx/pla_sigmoid/network.onnx
Binary file not shown.
Binary file added examples/onnx/pla_sigmoid/pwl_params.npz
Binary file not shown.
122 changes: 122 additions & 0 deletions examples/onnx/pla_sigmoid/pwl_sigmoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Piecewise Linear Approximation (PLA) for Sigmoid — replaces Lookup Table in ZK circuits.

- Uses only multiplications, additions, and comparisons (no exp, no lookup).
- Enables circuits to use copy/arithmetic constraints instead of lookup tables.
- Target accuracy: >= 99.5% (max relative error <= 0.5%).
"""
import numpy as np
import torch
import torch.nn as nn
from typing import Tuple

DEFAULT_X_MIN = -2.0
DEFAULT_X_MAX = 2.0


def sigmoid(x: np.ndarray) -> np.ndarray:
return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20)))


def fit_pwl_sigmoid(
x_min: float = DEFAULT_X_MIN,
x_max: float = DEFAULT_X_MAX,
n_segments: int = 16,
n_samples_per_seg: int = 200,
target_max_rel_error: float = 0.005,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Fit piecewise linear segments to sigmoid; returns breakpoints, slopes, intercepts."""
breakpoints = np.linspace(x_min, x_max, n_segments + 1)
slopes = np.zeros(n_segments)
intercepts = np.zeros(n_segments)

for i in range(n_segments):
left, right = breakpoints[i], breakpoints[i + 1]
xs = np.linspace(left, right, n_samples_per_seg)
ys = sigmoid(xs)
A = np.stack([xs, np.ones_like(xs)], axis=1)
(a, b), _, _, _ = np.linalg.lstsq(A, ys, rcond=None)
slopes[i] = a
intercepts[i] = b

x_check = np.linspace(x_min, x_max, 2000)
y_true = sigmoid(x_check)
y_pwl = np.zeros_like(x_check)
for i in range(n_segments):
mask = (x_check >= breakpoints[i]) & (x_check < breakpoints[i + 1])
if i == n_segments - 1:
mask = (x_check >= breakpoints[i]) & (x_check <= breakpoints[i + 1])
y_pwl[mask] = slopes[i] * x_check[mask] + intercepts[i]
denom = np.maximum(np.abs(y_true), 1e-8)
rel_err = np.abs(y_pwl - y_true) / denom
max_rel = float(np.max(rel_err))

if max_rel > target_max_rel_error and n_segments < 64:
return fit_pwl_sigmoid(
x_min, x_max, n_segments=min(n_segments + 8, 64),
n_samples_per_seg=n_samples_per_seg,
target_max_rel_error=target_max_rel_error,
)
return breakpoints, slopes, intercepts


def pwl_sigmoid_numpy(x, breakpoints, slopes, intercepts):
"""NumPy PWL sigmoid for verification."""
out = np.zeros_like(x, dtype=np.float64)
n_segments = len(slopes)
for i in range(n_segments):
left, right = breakpoints[i], breakpoints[i + 1]
mask = (x >= left) & (x <= right) if i == n_segments - 1 else (x >= left) & (x < right)
out[mask] = slopes[i] * x[mask] + intercepts[i]
out[x < breakpoints[0]] = sigmoid(x[x < breakpoints[0]])
out[x > breakpoints[-1]] = sigmoid(x[x > breakpoints[-1]])
return out


class PWLSigmoid(nn.Module):
"""Piecewise linear Sigmoid: only mul/add/compare, no lookup. ZK-friendly."""

def __init__(self, breakpoints: np.ndarray, slopes: np.ndarray, intercepts: np.ndarray):
super().__init__()
self.register_buffer("breakpoints", torch.tensor(breakpoints, dtype=torch.float32))
self.register_buffer("slopes", torch.tensor(slopes, dtype=torch.float32))
self.register_buffer("intercepts", torch.tensor(intercepts, dtype=torch.float32))
self.n_segments = len(slopes)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.clamp(x, self.breakpoints[0], self.breakpoints[-1])
out = torch.zeros_like(x)
for i in range(self.n_segments):
left = self.breakpoints[i]
right = self.breakpoints[i + 1]
mask = (x >= left) & (x <= right) if i == self.n_segments - 1 else (x >= left) & (x < right)
out = out + mask.float() * (self.slopes[i] * x + self.intercepts[i])
return out


def calibrate_and_save(
x_min: float = DEFAULT_X_MIN,
x_max: float = DEFAULT_X_MAX,
target_accuracy: float = 0.995,
out_path: str = "pwl_params.npz",
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Calibrate PWL to meet target accuracy and save to npz."""
target_max_rel = 1.0 - target_accuracy
for n in [8, 16, 32, 64]:
try:
bp, sl, ic = fit_pwl_sigmoid(x_min, x_max, n_segments=n, target_max_rel_error=target_max_rel)
x_check = np.linspace(x_min, x_max, 2000)
y_true = sigmoid(x_check)
y_pwl = pwl_sigmoid_numpy(x_check, bp, sl, ic)
rel_err = np.abs(y_pwl - y_true) / np.maximum(np.abs(y_true), 1e-8)
max_rel = float(np.max(rel_err))
if (1.0 - max_rel) >= target_accuracy:
np.savez(out_path, breakpoints=bp, slopes=sl, intercepts=ic)
return bp, sl, ic
except Exception:
continue
raise RuntimeError(f"Could not reach {target_accuracy*100}% accuracy with 8--64 segments")


if __name__ == "__main__":
calibrate_and_save(target_accuracy=0.995)