Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 83 additions & 35 deletions dot_ring/ring_proof/columns/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,15 @@
from dataclasses import dataclass
from typing import cast

from dot_ring.ring_proof.constants import (
MAX_RING_SIZE,
OMEGA,
S_PRIME,
SIZE,
Blinding_Base,
PaddingPoint,
SeedPoint,
)
from dot_ring.ring_proof.constants import DEFAULT_SIZE, MAX_RING_SIZE, OMEGAS, S_PRIME, Blinding_Base, PaddingPoint, SeedPoint
from dot_ring.ring_proof.curve.bandersnatch import TwistedEdwardCurve as TE
from dot_ring.ring_proof.helpers import Helpers as H
from dot_ring.ring_proof.params import RingProofParams
from dot_ring.ring_proof.pcs.kzg import KZG
from dot_ring.ring_proof.polynomial.interpolation import poly_interpolate_fft

h_vec = json.load(open(os.path.join(os.path.dirname(__file__), "h_vec.json")))
_H_VEC_DEFAULT = json.load(open(os.path.join(os.path.dirname(__file__), "h_vec.json")))
_H_VEC_DEFAULT = [tuple(pt) for pt in _H_VEC_DEFAULT]

Scalar = int
G1Point = tuple
Expand All @@ -31,12 +25,14 @@ class Column:
evals: list[int]
coeffs: list[int] | None = None
commitment: G1Point | None = None
size: int = SIZE
size: int = DEFAULT_SIZE

def interpolate(self, domain_omega: int = OMEGA, prime: int = S_PRIME) -> None:
def interpolate(self, domain_omega: int = OMEGAS[DEFAULT_SIZE], prime: int = S_PRIME) -> None:
"""Fill `self.coeffs` from `self.evals` using FFT interpolation."""
if self.coeffs is None:
self.evals += [0] * (SIZE - len(self.evals))
if len(self.evals) > self.size:
raise ValueError(f"{self.name} evals length {len(self.evals)} exceeds column size {self.size}")
self.evals += [0] * (self.size - len(self.evals))
self.coeffs = poly_interpolate_fft(self.evals, domain_omega, prime)

def commit(self) -> None:
Expand All @@ -48,15 +44,27 @@ def commit(self) -> None:

@dataclass(slots=True)
class PublicColumnBuilder:
size: int = SIZE
size: int = DEFAULT_SIZE
prime: int = S_PRIME
omega: int = OMEGA
omega: int = OMEGAS[DEFAULT_SIZE]
max_ring_size: int = MAX_RING_SIZE
padding_rows: int = 4

@classmethod
def from_params(cls, params: RingProofParams) -> PublicColumnBuilder:
return cls(
size=params.domain_size,
prime=params.prime,
omega=params.omega,
max_ring_size=params.max_ring_size,
padding_rows=params.padding_rows,
)

def _pad_ring_with_padding_point(self, pk_ring: list[tuple[int, int]], size: int = MAX_RING_SIZE) -> list[tuple[int, int]]:
def _pad_ring_with_padding_point(self, pk_ring: list[tuple[int, int]]) -> list[tuple[int, int]]:
"""Pad ring in‑place with the special padding point until size."""
# padding_sw = sw.from_twisted_edwards(PaddingPoint)
padding_sw = PaddingPoint
while len(pk_ring) < MAX_RING_SIZE:
while len(pk_ring) < self.max_ring_size:
pk_ring.append(padding_sw)
return pk_ring

Expand All @@ -70,23 +78,31 @@ def _h_vector(self, blinding_base: tuple[int, int] = Blinding_Base) -> list[tupl

def build(self, ring_pk: list[tuple[int, int]]) -> tuple[Column, Column, Column]:
"""Return (Px, Py, s) columns fully committed."""
if len(ring_pk) < MAX_RING_SIZE:
if len(ring_pk) < self.max_ring_size:
ring_pk = self._pad_ring_with_padding_point(ring_pk)
if len(ring_pk) > self.size - self.padding_rows:
raise ValueError(f"ring size {len(ring_pk)} exceeds max supported size {self.size - self.padding_rows}")
# 1. ensure ring size
for i in range(self.size - 4 - len(ring_pk)):
ring_pk.append(h_vec[i])
ring_pk.extend([(0, 0)] * 4)
fill_count = self.size - self.padding_rows - len(ring_pk)
if fill_count > 0:
if self.size == len(_H_VEC_DEFAULT):
h_vec = _H_VEC_DEFAULT
else:
h_vec = self._h_vector()
ring_pk.extend(h_vec[:fill_count])
if self.padding_rows > 0:
ring_pk.extend([(0, 0)] * self.padding_rows)

# 2. unzip into x/y vectors
px, py = H.unzip(ring_pk)

# 3. selector vector
sel = [1 if i < MAX_RING_SIZE else 0 for i in range(self.size)]
sel = [1 if i < self.max_ring_size else 0 for i in range(self.size)]

# 4. Columns
col_px = Column("Px", px)
col_py = Column("Py", py)
col_s = Column("s", sel)
col_px = Column("Px", px, size=self.size)
col_py = Column("Py", py, size=self.size)
col_s = Column("s", sel, size=self.size)
for col in (col_px, col_py, col_s):
col.interpolate(self.omega, self.prime)
col.commit()
Expand All @@ -99,15 +115,45 @@ class WitnessColumnBuilder:
selector_vector: list[int]
producer_index: int
secret_t: int
size: int = SIZE
omega: int = OMEGA
size: int = DEFAULT_SIZE
omega: int = OMEGAS[DEFAULT_SIZE]
prime: int = S_PRIME
max_ring_size: int = MAX_RING_SIZE
padding_rows: int = 4

@classmethod
def from_params(
cls,
ring_pk: list[tuple[int, int]],
selector_vector: list[int],
producer_index: int,
secret_t: int,
params: RingProofParams,
) -> WitnessColumnBuilder:
return cls(
ring_pk=ring_pk,
selector_vector=selector_vector,
producer_index=producer_index,
secret_t=secret_t,
size=params.domain_size,
omega=params.omega,
prime=params.prime,
max_ring_size=params.max_ring_size,
padding_rows=params.padding_rows,
)

def _bits_vector(self) -> list[int]:
bv = [1 if i == self.producer_index else 0 for i in range(MAX_RING_SIZE)]
bv = [1 if i == self.producer_index else 0 for i in range(self.max_ring_size)]
t_bits = bin(self.secret_t)[2:][::-1]
bv.extend(int(b) for b in t_bits)
while len(bv) < self.size - 4:
pad_to = self.size - self.padding_rows
if len(bv) > pad_to:
raise ValueError(
"b vector length exceeds available rows: "
f"{len(bv)} > {pad_to} (ring_size={self.max_ring_size}, "
f"secret_t_bits={len(t_bits)}, padding_rows={self.padding_rows})"
)
while len(bv) < pad_to:
bv.append(0)
bv.append(0) # padding bit
return bv
Expand All @@ -116,14 +162,16 @@ def _conditional_sum_accumulator(self, b_vector: list[int]) -> tuple[list[int],
seed_sw = SeedPoint

acc = [seed_sw]
for i in range(1, self.size - 3):
acc_len = self.size - self.padding_rows + 1
for i in range(1, acc_len):
next_pt = acc[i - 1] if b_vector[i - 1] == 0 else cast(tuple[int, int], TE.add(acc[i - 1], self.ring_pk[i - 1]))
acc.append(next_pt)
return H.unzip(acc)

def _inner_product_accumulator(self, b_vector: list[int]) -> list[int]:
acc = [0]
for i in range(1, self.size - 3):
acc_len = self.size - self.padding_rows + 1
for i in range(1, acc_len):
acc.append(acc[i - 1] + b_vector[i - 1] * self.selector_vector[i - 1])
return acc

Expand All @@ -133,10 +181,10 @@ def build(self) -> tuple[Column, Column, Column, Column]:
acc_ip = self._inner_product_accumulator(b_vec)

columns = [
Column("b", b_vec),
Column("accx", acc_x),
Column("accy", acc_y),
Column("accip", acc_ip),
Column("b", b_vec, size=self.size),
Column("accx", acc_x, size=self.size),
Column("accy", acc_y, size=self.size),
Column("accip", acc_ip, size=self.size),
]
for col in columns:
col.interpolate(self.omega, self.prime)
Expand Down
35 changes: 19 additions & 16 deletions dot_ring/ring_proof/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,28 @@
S_A: int = 10773120815616481058602537765553212789256758185246796157495669123169359657269
S_B: int = 29569587568322301171008055308580903175558631321415017492731745847794083609535

DEFAULT_SIZE: int = 512

OMEGA_2048: int = 49307615728544765012166121802278658070711169839041683575071795236746050763237


# 512‑th root
OMEGA_USED: int = 4214636447306890335450803789410475782380792963881561516561680164772024173390

# Compute the 512‑th root ourselves to cross‑check
SIZE: int = 512 # FFT domain size for witness polynomials
OMEGA: int = pow(OMEGA_2048, 2048 // SIZE, S_PRIME)


# if OMEGA != OMEGA_USED: # Guardrail to detect accidental param drift
# raise ValueError("Computed 512‑th root does not match reference value")
OMEGA_1024 = pow(OMEGA_2048, 2048 // 1024, S_PRIME)
OMEGA_512: int = pow(OMEGA_2048, 2048 // 512, S_PRIME)

# Pre‑compute the entire evaluation domain for fast access.
D_512: list[int] = [pow(OMEGA, i, S_PRIME) for i in range(SIZE)]
D_512: list[int] = [pow(OMEGA_512, i, S_PRIME) for i in range(512)]
D_1024: list[int] = [pow(OMEGA_1024, i, S_PRIME) for i in range(1024)]
D_2048: list[int] = [pow(OMEGA_2048, i, S_PRIME) for i in range(2048)]

OMEGAS = {
512: OMEGA_512,
1024: OMEGA_1024,
2048: OMEGA_2048,
}

EVAL_DOMAINS = {
512: D_512,
1024: D_1024,
2048: D_2048,
}

MAX_RING_SIZE: int = 255 # Upper bound enforced by the constraint system

Expand All @@ -69,9 +72,9 @@
"S_A",
"S_B",
"OMEGA_2048",
"OMEGA_USED",
"OMEGA",
"SIZE",
"OMEGA_1024",
"OMEGA_512",
"DEFAULT_SIZE",
"D_512",
"D_2048",
"MAX_RING_SIZE",
Expand Down
16 changes: 8 additions & 8 deletions dot_ring/ring_proof/constraints/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
from collections.abc import Sequence

from dot_ring.curve.native_field.vector_ops import vect_add
from dot_ring.ring_proof.constants import D_512 as D
from dot_ring.ring_proof.constants import S_PRIME
from dot_ring.ring_proof.polynomial.interpolation import (
poly_interpolate_fft,
poly_mul_fft,
)
from dot_ring.ring_proof.params import RingProofParams
from dot_ring.ring_proof.polynomial.interpolation import poly_interpolate_fft
from dot_ring.ring_proof.polynomial.ops import (
poly_multiply,
vect_scalar_mul,
Expand All @@ -20,10 +17,10 @@
]


def vanishing_poly(k: int, omega_root: int, prime: int = S_PRIME) -> list[int]:
def vanishing_poly(domain: list[int], k: int = 3, prime: int = S_PRIME) -> list[int]:
vanishing_term = [1]
for i in range(1, k + 1):
vanishing_term = poly_mul_fft(vanishing_term, [-D[-i], 1], prime)
vanishing_term = poly_multiply(vanishing_term, [-domain[-i], 1], prime)
return vanishing_term


Expand All @@ -33,6 +30,7 @@ def aggregate_constraints(
omega_root: int,
prime: int = S_PRIME,
k: int = 3,
domain: list[int] | None = None,
) -> list[int]:
result = [0] * len(polys[0])
for poly, alpha in zip(polys, alphas, strict=False):
Expand All @@ -41,7 +39,9 @@ def aggregate_constraints(
interpolated_result = poly_interpolate_fft(result, omega_root, prime)

# get vanishing ply
v_t = vanishing_poly(k, omega_root, prime)
if domain is None:
domain = RingProofParams().domain
v_t = vanishing_poly(domain, k, prime)
# mul with c_agg
final_cs_agg = poly_multiply(interpolated_result, v_t, prime)

Expand Down
Loading
Loading