Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: Lint

on:
pull_request:
branches: [main, dev]
push:
branches: [main]

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: "latest"

- name: Set up Python
run: uv python install 3.12

- name: Install dependencies
run: uv sync

- name: Run ruff check
run: uv run --with ruff ruff check .

- name: Run ruff format check
run: uv run --with ruff ruff format --check .
10 changes: 7 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
name: Build & Release

on:
pull_request:
branches: [main]
push:
branches: [main, fix/build-*]
branches: [fix/build-*]
tags:
- "v*.*.*"
- "*.*.*"
Expand All @@ -11,6 +13,7 @@ jobs:
build_wheels:
name: Build wheels on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
if: startsWith(github.ref, 'refs/tags/') || matrix.os == 'macos-latest'
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
Expand All @@ -26,7 +29,7 @@ jobs:
uses: pypa/[email protected]
env:
# Python versions to build
CIBW_SKIP: "pp* cp38-* cp39-* cp310-* cp311-* *-musllinux_*"
CIBW_SKIP: "pp* cp38-* cp39-* cp310-* *-musllinux_*"

# Use manylinux_2_28 (glibc 2.28+)
CIBW_MANYLINUX_X86_64_IMAGE: manylinux_2_28
Expand All @@ -42,7 +45,7 @@ jobs:
CIBW_ENVIRONMENT_MACOS: MACOSX_DEPLOYMENT_TARGET=11.0

# Install Python build dependencies
CIBW_BEFORE_BUILD: pip install Cython setuptools wheel
CIBW_BEFORE_BUILD: pip install Cython setuptools wheel "packaging>=24.2"

# Architectures to build
CIBW_ARCHS_LINUX: x86_64 aarch64
Expand All @@ -56,6 +59,7 @@ jobs:
build_sdist:
name: Build source distribution
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- uses: actions/checkout@v4

Expand Down
21 changes: 21 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-yaml
- id: check-toml
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
additional_dependencies: [types-setuptools]
31 changes: 6 additions & 25 deletions dot_ring/curve/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,9 @@ def _validate_parameters(self) -> bool:
# Original scalar field validation
# Allow int or custom Scalar types
if isinstance(self.GENERATOR_X, int) and isinstance(self.GENERATOR_Y, int):
if not (
0 <= self.GENERATOR_X < self.PRIME_FIELD
and 0 <= self.GENERATOR_Y < self.PRIME_FIELD
):
if not (0 <= self.GENERATOR_X < self.PRIME_FIELD and 0 <= self.GENERATOR_Y < self.PRIME_FIELD):
return False
elif not (
isinstance(self.GENERATOR_X, (tuple, list))
or isinstance(self.GENERATOR_Y, (tuple, list))
):
elif not (isinstance(self.GENERATOR_X, (tuple, list)) or isinstance(self.GENERATOR_Y, (tuple, list))):
# Assume custom field elements (like Scalar)
# We can't easily check bounds against int PRIME_FIELD if they are opaque,
# but we assume they are valid if they are passed.
Expand All @@ -124,12 +118,7 @@ def _validate_parameters(self) -> bool:
# if not self.is_on_curve(self.GENERATOR_X, self.GENERATOR_Y): #already given in point class
# return False

return (
self.PRIME_FIELD > 2
and self.ORDER > 2
and self.COFACTOR > 0
and self.PRIME_FIELD != self.ORDER
)
return self.PRIME_FIELD > 2 and self.ORDER > 2 and self.COFACTOR > 0 and self.PRIME_FIELD != self.ORDER

def hash_to_field(self, msg: bytes, count: int) -> list[int]:
"""
Expand Down Expand Up @@ -181,11 +170,7 @@ def expand_message_xmd(self, msg: bytes, len_in_bytes: int) -> bytes:
b_in_bytes = self.H_A().digest_size
ell = math.ceil(len_in_bytes / b_in_bytes)

if (
(ell > 255 and self.PRIME_FIELD.bit_length() < 384)
or len_in_bytes > 65535
or len(self.DST) > 255
):
if (ell > 255 and self.PRIME_FIELD.bit_length() < 384) or len_in_bytes > 65535 or len(self.DST) > 255:
# Relax ell check for large curves like P-521 where len_in_bytes might be large
# But strictly, RFC 9380 says ell <= 255.
# If len_in_bytes is huge, maybe we should just allow it if the curve is large?
Expand Down Expand Up @@ -214,9 +199,7 @@ def expand_message_xmd(self, msg: bytes, len_in_bytes: int) -> bytes:
# but 26KB seems wrong.
# Let's assume the user knows what they are doing if they request such length.
# But XMD structure relies on 1 byte for 'i' in loop. So ell cannot exceed 255.
raise ValueError(
f"Invalid input size parameters: ell={ell}, len={len_in_bytes}, dst_len={len(self.DST)}"
)
raise ValueError(f"Invalid input size parameters: ell={ell}, len={len_in_bytes}, dst_len={len(self.DST)}")

DST_prime = self.DST + self.I2OSP(len(self.DST), 1)
Z_pad = self.I2OSP(0, cast(int, self.S_in_bytes))
Expand All @@ -231,9 +214,7 @@ def expand_message_xmd(self, msg: bytes, len_in_bytes: int) -> bytes:

b_values = [b_1]
for i in range(2, ell + 1):
b_i = self.H_A(
self.strxor(b_0, b_values[-1]) + self.I2OSP(i, 1) + DST_prime
).digest()
b_i = self.H_A(self.strxor(b_0, b_values[-1]) + self.I2OSP(i, 1) + DST_prime).digest()
b_values.append(b_i)

uniform_bytes = b"".join(b_values)
Expand Down
16 changes: 4 additions & 12 deletions dot_ring/curve/field_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,15 @@ def __add__(self, other: FieldElement | int) -> FieldElement:
if isinstance(other, FieldElement):
if self.p != other.p:
raise ValueError("Cannot add elements from different fields")
return FieldElement(
(self.re + other.re) % self.p, (self.im + other.im) % self.p, self.p
)
return FieldElement((self.re + other.re) % self.p, (self.im + other.im) % self.p, self.p)
return FieldElement((self.re + other) % self.p, self.im, self.p)

def __sub__(self, other: FieldElement | int) -> FieldElement:
"""Subtract two field elements or a field element and an integer."""
if isinstance(other, FieldElement):
if self.p != other.p:
raise ValueError("Cannot subtract elements from different fields")
return FieldElement(
(self.re - other.re) % self.p, (self.im - other.im) % self.p, self.p
)
return FieldElement((self.re - other.re) % self.p, (self.im - other.im) % self.p, self.p)
return FieldElement((self.re - other) % self.p, self.im, self.p)

def __mul__(self, other: FieldElement | int) -> FieldElement:
Expand All @@ -58,9 +54,7 @@ def __mul__(self, other: FieldElement | int) -> FieldElement:
re = (self.re * other.re - self.im * other.im) % self.p
im = (self.re * other.im + self.im * other.re) % self.p
return FieldElement(re, im, self.p)
return FieldElement(
(self.re * other) % self.p, (self.im * other) % self.p, self.p
)
return FieldElement((self.re * other) % self.p, (self.im * other) % self.p, self.p)

def __truediv__(self, other: FieldElement | int) -> FieldElement:
"""Divide two field elements or a field element by an integer."""
Expand All @@ -74,9 +68,7 @@ def inv(self) -> FieldElement:
# For Fp2, the inverse of (a + bi) is (a - bi)/(a² + b²)
denom = (self.re * self.re + self.im * self.im) % self.p
inv_denom = pow(denom, -1, self.p)
return FieldElement(
(self.re * inv_denom) % self.p, (-self.im * inv_denom) % self.p, self.p
)
return FieldElement((self.re * inv_denom) % self.p, (-self.im * inv_denom) % self.p, self.p)

def __neg__(self) -> FieldElement:
"""Negate the field element."""
Expand Down
27 changes: 5 additions & 22 deletions dot_ring/curve/glv.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,7 @@ def _validate_parameters(self) -> bool:
"""
return self.lambda_param != 0 and self.constant_b != 0 and self.constant_c != 0

def extended_euclidean_algorithm(
self, n: int, lam: int
) -> list[tuple[int, int, int]]:
def extended_euclidean_algorithm(self, n: int, lam: int) -> list[tuple[int, int, int]]:
"""
Compute extended Euclidean algorithm sequence.

Expand Down Expand Up @@ -84,9 +82,7 @@ def extended_euclidean_algorithm(
return sequence[:-1]

@lru_cache(maxsize=1024) # noqa: B019
def find_short_vectors(
self, n: int, lam: int
) -> tuple[tuple[int, int], tuple[int, int]]:
def find_short_vectors(self, n: int, lam: int) -> tuple[tuple[int, int], tuple[int, int]]:
"""
Find short vectors for scalar decomposition.

Expand Down Expand Up @@ -187,9 +183,7 @@ def compute_endomorphism(self, point: AffinePointT) -> AffinePointT:

return point.__class__(x_a, y_a)

def windowed_simultaneous_mult(
self, k1: int, k2: int, P1: AffinePointT, P2: AffinePointT, w: int = 2
) -> AffinePointT:
def windowed_simultaneous_mult(self, k1: int, k2: int, P1: AffinePointT, P2: AffinePointT, w: int = 2) -> AffinePointT:
"""
Compute k1 * P1 + k2 * P2 using windowed simultaneous multi-scalar multiplication.

Expand Down Expand Up @@ -235,9 +229,7 @@ def windowed_simultaneous_mult(

assert projective_to_affine is not None
# Use compiled MSM
rx, ry, rz, rt = _compiled_msm(
k1, k2, P1.x, P1.y, 1, p1_t, P2.x, P2.y, 1, p2_t, a_coeff, d_coeff, p, w
)
rx, ry, rz, rt = _compiled_msm(k1, k2, P1.x, P1.y, 1, p1_t, P2.x, P2.y, 1, p2_t, a_coeff, d_coeff, p, w)

# Convert back to affine
ax, ay = projective_to_affine(rx, ry, rz, p)
Expand Down Expand Up @@ -288,16 +280,7 @@ def multi_scalar_mult_4(
d_coeff = P1.curve.EdwardsD

# Convert to projective coordinates
if (
P1.x is None
or P1.y is None
or P2.x is None
or P2.y is None
or P3.x is None
or P3.y is None
or P4.x is None
or P4.y is None
):
if P1.x is None or P1.y is None or P2.x is None or P2.y is None or P3.x is None or P3.y is None or P4.x is None or P4.y is None:
# Fallback to simple addition for identity points
res = P1 * k1 + P2 * k2 # type: ignore[operator]
res = res + P3 * k3 # type: ignore[operator]
Expand Down
28 changes: 7 additions & 21 deletions dot_ring/curve/montgomery/mg_affine_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ def __add__(self, other: MGAffinePoint[C]) -> MGAffinePoint[C]:
# if y == 0 then slope denominator = 0 => result is identity
if y1 % p == 0:
return self.__class__(None, None)
numerator = (
3 * cast(int, x1) * cast(int, x1) + 2 * cast(int, A) * cast(int, x1) + 1
) % p
numerator = (3 * cast(int, x1) * cast(int, x1) + 2 * cast(int, A) * cast(int, x1) + 1) % p
denominator = (2 * cast(int, B) * cast(int, y1)) % p
# Check if denominator is zero before computing inverse
if denominator == 0:
Expand All @@ -99,9 +97,7 @@ def __add__(self, other: MGAffinePoint[C]) -> MGAffinePoint[C]:
raise ValueError("Unexpected zero denominator in point addition")
lam = (numerator * pow(denominator, -1, p)) % p
# Corrected formula for x3 in point addition
x3 = (
cast(int, B) * lam * lam - cast(int, A) - cast(int, x1) - cast(int, x2)
) % p
x3 = (cast(int, B) * lam * lam - cast(int, A) - cast(int, x1) - cast(int, x2)) % p
# Corrected formula for y3
y3 = (lam * (cast(int, x1) - x3) - cast(int, y1)) % p
return self.__class__(x3, y3)
Expand Down Expand Up @@ -266,18 +262,14 @@ def encode_to_curve(
# Check if it's an ELL2 variant (ELL2 or ELL2_NU)
if cls.curve.E2C in (E2C_Variant.ELL2, E2C_Variant.ELL2_NU):
if cls.curve.E2C.value.endswith("_NU_"):
return cls.encode_to_curve_hash2_suite_nu(
alpha_string, salt, General_Check
)
return cls.encode_to_curve_hash2_suite_nu(alpha_string, salt, General_Check)

return cls.encode_to_curve_hash2_suite_ro(alpha_string, salt, General_Check)
else:
raise ValueError(f"Unexpected E2C Variant: {cls.curve.E2C}")

@classmethod
def encode_to_curve_hash2_suite_nu(
cls, alpha_string: bytes, salt: bytes = b"", General_Check: bool = False
) -> MGAffinePoint[C] | Any:
def encode_to_curve_hash2_suite_nu(cls, alpha_string: bytes, salt: bytes = b"", General_Check: bool = False) -> MGAffinePoint[C] | Any:
"""
Encode a string to a curve point using Elligator 2.

Expand All @@ -297,9 +289,7 @@ def encode_to_curve_hash2_suite_nu(
return R.clear_cofactor()

@classmethod
def encode_to_curve_hash2_suite_ro(
cls, alpha_string: bytes, salt: bytes = b"", General_Check: bool = False
) -> MGAffinePoint[C] | Any:
def encode_to_curve_hash2_suite_ro(cls, alpha_string: bytes, salt: bytes = b"", General_Check: bool = False) -> MGAffinePoint[C] | Any:
"""
Encode a string to a curve point using Elligator 2.

Expand Down Expand Up @@ -390,12 +380,8 @@ def point_to_string(self) -> bytes:
# Encode u and v coordinates as little-endian bytes
if self.x is None or self.y is None:
raise ValueError("Cannot serialize identity point")
x_bytes = int(cast(int, self.x)).to_bytes(
field_byte_len, cast(Literal["little", "big"], self.curve.ENDIAN)
)
y_bytes = int(cast(int, self.y)).to_bytes(
field_byte_len, cast(Literal["little", "big"], self.curve.ENDIAN)
)
x_bytes = int(cast(int, self.x)).to_bytes(field_byte_len, cast(Literal["little", "big"], self.curve.ENDIAN))
y_bytes = int(cast(int, self.y)).to_bytes(field_byte_len, cast(Literal["little", "big"], self.curve.ENDIAN))
return x_bytes + y_bytes
else:
raise NotImplementedError("Compressed encoding not implemented")
Expand Down
11 changes: 2 additions & 9 deletions dot_ring/curve/montgomery/mg_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,7 @@ def __eq__(self, other: object) -> bool:
"""Check if two curves are equal."""
if not isinstance(other, MGCurve):
return False
return (
self.PRIME_FIELD == other.PRIME_FIELD
and self.A == other.A
and self.B == other.B
)
return self.PRIME_FIELD == other.PRIME_FIELD and self.A == other.A and self.B == other.B

def __hash__(self) -> int:
"""Hash for use as dictionary keys."""
Expand All @@ -142,7 +138,4 @@ def __str__(self) -> str:

def __repr__(self) -> str:
"""Detailed string representation."""
return (
f"MGCurve(PRIME_FIELD={self.PRIME_FIELD}, A={self.A}, B={self.B}, "
f"equation: {self.B}*v² = u³ + {self.A}*u² + u)"
)
return f"MGCurve(PRIME_FIELD={self.PRIME_FIELD}, A={self.A}, B={self.B}, equation: {self.B}*v² = u³ + {self.A}*u² + u)"
Loading
Loading