diff --git a/examples/onnx/pla_sigmoid/README.md b/examples/onnx/pla_sigmoid/README.md new file mode 100644 index 000000000..38676af00 --- /dev/null +++ b/examples/onnx/pla_sigmoid/README.md @@ -0,0 +1,56 @@ +# PLA Sigmoid: Replace Lookup with Piecewise Linear (Copy/Arithmetic Constraints) + +This example shows how to **avoid Sigmoid lookup tables** in ezkl by using a **piecewise linear approximation (PLA)** that uses only multiplications, additions, and comparisons. The circuit then relies on **copy and arithmetic constraints** instead of lookups, which can reduce proof cost and simplify settings. + +## Why this helps + +- **Sigmoid** is often implemented in ezkl via a **lookup table**, which can be expensive in circuit size and proving time. +- Replacing it with a **PLA** (segment-wise linear: `y = slope_i * x + intercept_i` per segment) gives: + - **No lookup**: only mul/add/compare ops → no `lookup_range` or large lookup tables. + - **Lighter settings**: e.g. `num_inner_cols=1`, `lookup_range=(0,1)`, `bounded_log_lookup=False`. + - **Faster proving** in practice when the rest of the circuit is comparable. + +## Accuracy + +- The PLA is fitted so that **max relative error vs true sigmoid ≤ 0.5%** (i.e. **≥ 99.5% accuracy**) over the chosen input range (default `[-2, 2]`). +- You can recalibrate with `python pwl_sigmoid.py` (writes `pwl_params.npz`). + +## Files + +| File | Description | +|------|-------------| +| `pwl_sigmoid.py` | PLA fitting and `PWLSigmoid` PyTorch module (mul/add/compare only). | +| `gen.py` | Builds the model (conv + PLA Sigmoid), exports ONNX and `input.json`. | +| `network.onnx` | Exported model (no Sigmoid op; PLA is expanded into linear ops). | +| `input.json` | Sample inputs and shapes for ezkl. | +| `pwl_params.npz` | Pre-fitted PLA parameters (optional; `gen.py` can regenerate). | + +## How to run + +1. **Export ONNX and input (if not already present):** + ```bash + cd examples/onnx/pla_sigmoid + pip install torch numpy + python gen.py + ``` + This produces `network.onnx` and `input.json` (and optionally `pwl_params.npz`). + +2. **Use with ezkl** (no lookup needed; lightweight settings): + ```python + import ezkl + py_run_args = ezkl.PyRunArgs() + py_run_args.input_visibility = "public" + py_run_args.output_visibility = "public" + py_run_args.param_visibility = "fixed" + py_run_args.num_inner_cols = 1 + py_run_args.lookup_range = (0, 1) + py_run_args.bounded_log_lookup = False + py_run_args.logrows = 16 # or as required by your circuit size + ezkl.gen_settings("network.onnx", "settings.json", py_run_args=py_run_args) + ezkl.compile_circuit("network.onnx", "network.compiled", "settings.json") + # Then: SRS, setup, gen_witness, prove as in the main ezkl docs. + ``` + +## Author + +[changshenhan](https://github.com/changshenhan) — PLA-on-Sigmoid example for ezkl. diff --git a/examples/onnx/pla_sigmoid/gen.py b/examples/onnx/pla_sigmoid/gen.py new file mode 100644 index 000000000..44d2b5933 --- /dev/null +++ b/examples/onnx/pla_sigmoid/gen.py @@ -0,0 +1,85 @@ +""" +Export ONNX model that uses Piecewise Linear (PLA) Sigmoid instead of Lookup. + +This example shows how to avoid expensive Sigmoid lookup in ezkl by replacing it +with a PLA that uses only mul/add/compare — expressible with copy/arithmetic constraints. +Run from this directory: python gen.py +""" +import os +import numpy as np +import torch +import torch.onnx +import torch.nn as nn +import torch.nn.init as init +import json + +from pwl_sigmoid import PWLSigmoid, calibrate_and_save, fit_pwl_sigmoid + + +def _load_pwl_sigmoid(): + path = os.path.join(os.path.dirname(__file__), "pwl_params.npz") + if os.path.exists(path): + d = np.load(path) + return PWLSigmoid(d["breakpoints"], d["slopes"], d["intercepts"]) + breakpoints, slopes, intercepts = fit_pwl_sigmoid(n_segments=16, target_max_rel_error=0.005) + np.savez(path, breakpoints=breakpoints, slopes=slopes, intercepts=intercepts) + return PWLSigmoid(breakpoints, slopes, intercepts) + + +class Circuit(nn.Module): + """Small circuit: conv + PLA Sigmoid (no lookup).""" + + def __init__(self): + super().__init__() + self.relu = nn.ReLU() + self.act = _load_pwl_sigmoid() + self.conv = nn.Conv2d(3, 3, (2, 2), 1, 2) + init.orthogonal_(self.conv.weight) + + def forward(self, x, y, z): + x = self.act(self.conv(y @ x**2 + (x) - (self.relu(z)))) + 2 + return (x, self.relu(z) / 3) + + +def main(): + if not os.path.exists(os.path.join(os.path.dirname(__file__), "pwl_params.npz")): + calibrate_and_save(target_accuracy=0.995, out_path=os.path.join(os.path.dirname(__file__), "pwl_params.npz")) + torch_model = Circuit() + shape = [3, 2, 2] + x = 0.1 * torch.rand(1, *shape, requires_grad=True) + y = 0.1 * torch.rand(1, *shape, requires_grad=True) + z = 0.1 * torch.rand(1, *shape, requires_grad=True) + torch_out = torch_model(x, y, z) + + out_dir = os.path.dirname(__file__) + onnx_path = os.path.join(out_dir, "network.onnx") + input_path = os.path.join(out_dir, "input.json") + + torch.onnx.export( + torch_model, + (x, y, z), + onnx_path, + export_params=True, + opset_version=10, + do_constant_folding=True, + input_names=["input"], + output_names=["output"], + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + ) + + data = dict( + input_shapes=[shape, shape, shape], + input_data=[ + x.detach().numpy().reshape(-1).tolist(), + y.detach().numpy().reshape(-1).tolist(), + z.detach().numpy().reshape(-1).tolist(), + ], + output_data=[o.detach().numpy().reshape(-1).tolist() for o in torch_out], + ) + with open(input_path, "w") as f: + json.dump(data, f) + print("Exported network.onnx and input.json") + + +if __name__ == "__main__": + main() diff --git a/examples/onnx/pla_sigmoid/input.json b/examples/onnx/pla_sigmoid/input.json new file mode 100644 index 000000000..7e31e4f49 --- /dev/null +++ b/examples/onnx/pla_sigmoid/input.json @@ -0,0 +1 @@ +{"input_shapes": [[3, 2, 2], [3, 2, 2], [3, 2, 2]], "input_data": [[0.006402426864951849, 0.0546477846801281, 0.03791030868887901, 0.057746853679418564, 0.03026706539094448, 0.05770156532526016, 0.07304316759109497, 0.07641308754682541, 0.06932983547449112, 0.0041304826736450195, 0.009861445985734463, 0.07692736387252808], [0.0415552519261837, 0.013548016548156738, 0.046301256865262985, 0.09327216446399689, 0.045912325382232666, 0.06049604341387749, 0.07569151371717453, 0.07194950431585312, 0.013698828406631947, 0.09671473503112793, 0.08251019567251205, 0.04880925640463829], [0.000953847193159163, 0.06371661275625229, 0.09247473627328873, 0.04223586246371269, 0.01274173241108656, 0.09394653886556625, 0.035997193306684494, 0.011770474724471569, 0.038802023977041245, 0.0823185071349144, 0.07291588932275772, 0.0731731653213501]], "output_data": [[2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.530578136444092, 2.5382046699523926, 2.5266213417053223, 2.531947135925293, 2.531947135925293, 2.5237579345703125, 2.540153980255127, 2.5347211360931396, 2.531947135925293, 2.531947135925293, 2.543057680130005, 2.5285441875457764, 2.5347025394439697, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.531947135925293, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.527876853942871, 2.5258495807647705, 2.5168540477752686, 2.525843858718872, 2.525843858718872, 2.529510021209717, 2.527862071990967, 2.529017925262451, 2.525843858718872, 2.525843858718872, 2.521322250366211, 2.5252134799957275, 2.5183987617492676, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.525843858718872, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.5534555912017822, 2.5507540702819824, 2.5608303546905518, 2.553361654281616, 2.553361654281616, 2.561260461807251, 2.551497459411621, 2.547823667526245, 2.553361654281616, 2.553361654281616, 2.5519490242004395, 2.565648078918457, 2.5595531463623047, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616, 2.553361654281616], [0.0003179490740876645, 0.021238870918750763, 0.03082491271197796, 0.014078620821237564, 0.004247243981808424, 0.03131551295518875, 0.01199906412512064, 0.003923491574823856, 0.012934007681906223, 0.027439502999186516, 0.024305297061800957, 0.0243910551071167]]} \ No newline at end of file diff --git a/examples/onnx/pla_sigmoid/network.onnx b/examples/onnx/pla_sigmoid/network.onnx new file mode 100644 index 000000000..da8370a32 Binary files /dev/null and b/examples/onnx/pla_sigmoid/network.onnx differ diff --git a/examples/onnx/pla_sigmoid/pwl_params.npz b/examples/onnx/pla_sigmoid/pwl_params.npz new file mode 100644 index 000000000..8c04ea7a1 Binary files /dev/null and b/examples/onnx/pla_sigmoid/pwl_params.npz differ diff --git a/examples/onnx/pla_sigmoid/pwl_sigmoid.py b/examples/onnx/pla_sigmoid/pwl_sigmoid.py new file mode 100644 index 000000000..7e789ae03 --- /dev/null +++ b/examples/onnx/pla_sigmoid/pwl_sigmoid.py @@ -0,0 +1,122 @@ +""" +Piecewise Linear Approximation (PLA) for Sigmoid — replaces Lookup Table in ZK circuits. + +- Uses only multiplications, additions, and comparisons (no exp, no lookup). +- Enables circuits to use copy/arithmetic constraints instead of lookup tables. +- Target accuracy: >= 99.5% (max relative error <= 0.5%). +""" +import numpy as np +import torch +import torch.nn as nn +from typing import Tuple + +DEFAULT_X_MIN = -2.0 +DEFAULT_X_MAX = 2.0 + + +def sigmoid(x: np.ndarray) -> np.ndarray: + return 1.0 / (1.0 + np.exp(-np.clip(x, -20, 20))) + + +def fit_pwl_sigmoid( + x_min: float = DEFAULT_X_MIN, + x_max: float = DEFAULT_X_MAX, + n_segments: int = 16, + n_samples_per_seg: int = 200, + target_max_rel_error: float = 0.005, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Fit piecewise linear segments to sigmoid; returns breakpoints, slopes, intercepts.""" + breakpoints = np.linspace(x_min, x_max, n_segments + 1) + slopes = np.zeros(n_segments) + intercepts = np.zeros(n_segments) + + for i in range(n_segments): + left, right = breakpoints[i], breakpoints[i + 1] + xs = np.linspace(left, right, n_samples_per_seg) + ys = sigmoid(xs) + A = np.stack([xs, np.ones_like(xs)], axis=1) + (a, b), _, _, _ = np.linalg.lstsq(A, ys, rcond=None) + slopes[i] = a + intercepts[i] = b + + x_check = np.linspace(x_min, x_max, 2000) + y_true = sigmoid(x_check) + y_pwl = np.zeros_like(x_check) + for i in range(n_segments): + mask = (x_check >= breakpoints[i]) & (x_check < breakpoints[i + 1]) + if i == n_segments - 1: + mask = (x_check >= breakpoints[i]) & (x_check <= breakpoints[i + 1]) + y_pwl[mask] = slopes[i] * x_check[mask] + intercepts[i] + denom = np.maximum(np.abs(y_true), 1e-8) + rel_err = np.abs(y_pwl - y_true) / denom + max_rel = float(np.max(rel_err)) + + if max_rel > target_max_rel_error and n_segments < 64: + return fit_pwl_sigmoid( + x_min, x_max, n_segments=min(n_segments + 8, 64), + n_samples_per_seg=n_samples_per_seg, + target_max_rel_error=target_max_rel_error, + ) + return breakpoints, slopes, intercepts + + +def pwl_sigmoid_numpy(x, breakpoints, slopes, intercepts): + """NumPy PWL sigmoid for verification.""" + out = np.zeros_like(x, dtype=np.float64) + n_segments = len(slopes) + for i in range(n_segments): + left, right = breakpoints[i], breakpoints[i + 1] + mask = (x >= left) & (x <= right) if i == n_segments - 1 else (x >= left) & (x < right) + out[mask] = slopes[i] * x[mask] + intercepts[i] + out[x < breakpoints[0]] = sigmoid(x[x < breakpoints[0]]) + out[x > breakpoints[-1]] = sigmoid(x[x > breakpoints[-1]]) + return out + + +class PWLSigmoid(nn.Module): + """Piecewise linear Sigmoid: only mul/add/compare, no lookup. ZK-friendly.""" + + def __init__(self, breakpoints: np.ndarray, slopes: np.ndarray, intercepts: np.ndarray): + super().__init__() + self.register_buffer("breakpoints", torch.tensor(breakpoints, dtype=torch.float32)) + self.register_buffer("slopes", torch.tensor(slopes, dtype=torch.float32)) + self.register_buffer("intercepts", torch.tensor(intercepts, dtype=torch.float32)) + self.n_segments = len(slopes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, self.breakpoints[0], self.breakpoints[-1]) + out = torch.zeros_like(x) + for i in range(self.n_segments): + left = self.breakpoints[i] + right = self.breakpoints[i + 1] + mask = (x >= left) & (x <= right) if i == self.n_segments - 1 else (x >= left) & (x < right) + out = out + mask.float() * (self.slopes[i] * x + self.intercepts[i]) + return out + + +def calibrate_and_save( + x_min: float = DEFAULT_X_MIN, + x_max: float = DEFAULT_X_MAX, + target_accuracy: float = 0.995, + out_path: str = "pwl_params.npz", +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Calibrate PWL to meet target accuracy and save to npz.""" + target_max_rel = 1.0 - target_accuracy + for n in [8, 16, 32, 64]: + try: + bp, sl, ic = fit_pwl_sigmoid(x_min, x_max, n_segments=n, target_max_rel_error=target_max_rel) + x_check = np.linspace(x_min, x_max, 2000) + y_true = sigmoid(x_check) + y_pwl = pwl_sigmoid_numpy(x_check, bp, sl, ic) + rel_err = np.abs(y_pwl - y_true) / np.maximum(np.abs(y_true), 1e-8) + max_rel = float(np.max(rel_err)) + if (1.0 - max_rel) >= target_accuracy: + np.savez(out_path, breakpoints=bp, slopes=sl, intercepts=ic) + return bp, sl, ic + except Exception: + continue + raise RuntimeError(f"Could not reach {target_accuracy*100}% accuracy with 8--64 segments") + + +if __name__ == "__main__": + calibrate_and_save(target_accuracy=0.995)