diff --git a/lib/std/crypto/ml_kem.zig b/lib/std/crypto/ml_kem.zig
index 0a8e73f785e1..d48c61448e34 100644
--- a/lib/std/crypto/ml_kem.zig
+++ b/lib/std/crypto/ml_kem.zig
@@ -634,33 +634,11 @@ test "invNTTReductions bounds" {
     }
 }
 
-// Extended euclidean algorithm.
-//
-// For a, b finds x, y such that x a + y b = gcd(a, b). Used to compute
-// modular inverse.
-fn eea(a: anytype, b: @TypeOf(a)) EeaResult(@TypeOf(a)) {
-    if (a == 0) {
-        return .{ .gcd = b, .x = 0, .y = 1 };
-    }
-    const r = eea(@rem(b, a), a);
-    return .{ .gcd = r.gcd, .x = r.y - @divTrunc(b, a) * r.x, .y = r.x };
-}
-
-fn EeaResult(comptime T: type) type {
-    return struct { gcd: T, x: T, y: T };
-}
-
-// Returns least common multiple of a and b.
-fn lcm(a: anytype, b: @TypeOf(a)) @TypeOf(a) {
-    const r = eea(a, b);
-    return a * b / r.gcd;
-}
-
 // Invert modulo p.
 fn invertMod(a: anytype, p: @TypeOf(a)) @TypeOf(a) {
-    const r = eea(a, p);
+    const r = std.math.egcd(a, p); // r.bezout_coeff_1 * a + r.bezout_coeff_2 * p == r.gcd
     assert(r.gcd == 1);
-    return r.x;
+    return r.bezout_coeff_1; // with gcd == 1 this coefficient is an inverse of `a` modulo `p`
 }
 
 // Reduce mod q for testing.
@@ -1054,7 +1032,7 @@ const Poly = struct {
         var in_off: usize = 0;
         var out_off: usize = 0;
 
-        const batch_size: usize = comptime lcm(@as(i16, d), 8);
+        const batch_size: usize = comptime std.math.lcm(@as(i16, d), 8);
         const in_batch_size: usize = comptime batch_size / d;
         const out_batch_size: usize = comptime batch_size / 8;
 
@@ -1118,7 +1096,7 @@ const Poly = struct {
         var in_off: usize = 0;
         var out_off: usize = 0;
 
-        const batch_size: usize = comptime lcm(@as(i16, d), 8);
+        const batch_size: usize = comptime std.math.lcm(@as(i16, d), 8);
         const in_batch_size: usize = comptime batch_size / 8;
         const out_batch_size: usize = comptime batch_size / d;
 
diff --git a/lib/std/math.zig b/lib/std/math.zig
index c1b489a41d50..0a3ac9ad46ce 100644
--- a/lib/std/math.zig
+++ b/lib/std/math.zig
@@ -238,6 +238,7 @@ pub const sinh = @import("math/sinh.zig").sinh;
 pub const cosh = @import("math/cosh.zig").cosh;
 pub const tanh = @import("math/tanh.zig").tanh;
 pub const gcd = @import("math/gcd.zig").gcd;
+pub const egcd = @import("math/egcd.zig").egcd;
 pub const lcm = @import("math/lcm.zig").lcm;
 pub const gamma = @import("math/gamma.zig").gamma;
 pub const lgamma = @import("math/gamma.zig").lgamma;
diff --git a/lib/std/math/egcd.zig b/lib/std/math/egcd.zig
new file mode 100644
index 000000000000..ec7471724e7a
--- /dev/null
+++ b/lib/std/math/egcd.zig
@@ -0,0 +1,241 @@
+//! Extended Greatest Common Divisor (https://mathworld.wolfram.com/ExtendedGreatestCommonDivisor.html)
+const std = @import("../std.zig");
+
+/// Result type of `egcd` for operand type `S`: `gcd` is the unsigned integer with the bit width of `S` (or `comptime_int`), the Bézout coefficients are of type `S`.
+pub fn ExtendedGreatestCommonDivisor(comptime S: type) type {
+    const N = switch (S) {
+        comptime_int => comptime_int,
+        else => |T| std.meta.Int(.unsigned, @bitSizeOf(T)),
+    };
+
+    return struct {
+        gcd: N,
+        bezout_coeff_1: S,
+        bezout_coeff_2: S,
+    };
+}
+
+/// Returns the Extended Greatest Common Divisor (EGCD) of two signed integers (`a` and `b`) which are not both zero: the GCD together with Bézout coefficients such that `bezout_coeff_1 * a + bezout_coeff_2 * b == gcd`.
+pub fn egcd(a: anytype, b: anytype) ExtendedGreatestCommonDivisor(@TypeOf(a, b)) {
+    const S = switch (@TypeOf(a, b)) {
+        comptime_int => b: {
+            const n = @max(@abs(a), @abs(b));
+            break :b std.math.IntFittingRange(-n, n);
+        },
+        else => |T| T,
+    };
+    if (@typeInfo(S) != .int or @typeInfo(S).int.signedness != .signed) {
+        @compileError("`a` and `b` must be signed integers");
+    }
+
+    std.debug.assert(a != 0 or b != 0);
+
+    if (a == 0) return .{ .gcd = @abs(b), .bezout_coeff_1 = 0, .bezout_coeff_2 = std.math.sign(b) };
+    if (b == 0) return .{ .gcd = @abs(a), .bezout_coeff_1 = std.math.sign(a), .bezout_coeff_2 = 0 };
+
+    const other: S, const odd: S, const shift, const switch_coeff = b: {
+        const xz = @ctz(@as(S, a));
+        const yz = @ctz(@as(S, b));
+        break :b if (xz < yz) .{ b, a, xz, true } else .{ a, b, yz, false };
+    };
+    const toinv = @shrExact(other, @intCast(shift));
+    const ctrl = @shrExact(odd, @intCast(shift)); // Invariant: |s|, |t|, |ctrl| < |MIN_OF(S)|
+    const half_ctrl = 1 + @shrExact(ctrl - 1, 1); // equals (ctrl + 1) / 2 since ctrl is odd after the shift
+    const abs_ctrl = @abs(ctrl);
+
+    var s: S = std.math.sign(toinv);
+    var t: S = 0;
+
+    var x = @abs(toinv);
+    var y = abs_ctrl;
+
+    {
+        const xz = @ctz(x);
+        x = @shrExact(x, @intCast(xz));
+        for (0..xz) |_| {
+            const half_s = s >> 1;
+            if (s & 1 == 0)
+                s = half_s
+            else
+                s = half_s + half_ctrl; // s odd: this is (s + ctrl) / 2, i.e. s halved modulo ctrl
+        }
+    }
+
+    var y_minus_x = y -% x;
+    while (y_minus_x != 0) : (y_minus_x = y -% x) {
+        const t_minus_s = t - s;
+        const copy_x = x;
+        const copy_s = s;
+        const xz = @ctz(y_minus_x);
+
+        s -= t;
+        const carry = x < y; // true when `x -%= y` below wraps around
+        x -%= y;
+        if (carry) {
+            x = y_minus_x;
+            y = copy_x;
+            s = t_minus_s;
+            t = copy_s;
+        }
+        x = @shrExact(x, @intCast(xz));
+        for (0..xz) |_| {
+            const half_s = s >> 1;
+            if (s & 1 == 0)
+                s = half_s
+            else
+                s = half_s + half_ctrl;
+        }
+
+        if (s < 0) s = @intCast(abs_ctrl - @abs(s));
+    }
+
+    // Recover `t` from `s * toinv + t * ctrl == y`. Using integer widening is only a temporary solution.
+    const W = std.meta.Int(.signed, @bitSizeOf(S) * 2);
+    t = @intCast(@divExact(y - @as(W, s) * toinv, ctrl));
+    const final_s, const final_t = if (switch_coeff) .{ t, s } else .{ s, t };
+    return .{
+        .gcd = @shlExact(y, @intCast(shift)),
+        .bezout_coeff_1 = final_s,
+        .bezout_coeff_2 = final_t,
+    };
+}
+
+test {
+    {
+        const a: i2 = 0;
+        const b: i2 = 1;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s: i2 = r.bezout_coeff_1;
+        const t: i2 = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a: i8 = -128;
+        const b: i8 = 127;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s: i16 = r.bezout_coeff_1;
+        const t: i16 = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a: i16 = -32768;
+        const b: i16 = -32768;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s: i32 = r.bezout_coeff_1;
+        const t: i32 = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a: i32 = 128;
+        const b: i32 = 112;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s: i64 = r.bezout_coeff_1;
+        const t: i64 = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a: i32 = 4 * 89;
+        const b: i32 = 2 * 17;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s: i64 = r.bezout_coeff_1;
+        const t: i64 = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a: i8 = 127;
+        const b: i8 = 126;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s: i16 = r.bezout_coeff_1;
+        const t: i16 = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a: i4 = -8;
+        const b: i4 = 1;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s = r.bezout_coeff_1;
+        const t = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a: i4 = -8;
+        const b: i4 = 5;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        // Avoid overflow in assert.
+        const s: i8 = r.bezout_coeff_1;
+        const t: i8 = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a: i32 = 0;
+        const b: i32 = 5;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s = r.bezout_coeff_1;
+        const t = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a: i32 = 5;
+        const b: i32 = 0;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s = r.bezout_coeff_1;
+        const t = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+
+    {
+        const a: i32 = 21;
+        const b: i32 = 15;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s = r.bezout_coeff_1;
+        const t = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a: i32 = -21;
+        const b: i32 = 15;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s = r.bezout_coeff_1;
+        const t = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a = -21;
+        const b = 15;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s = r.bezout_coeff_1;
+        const t = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a = 927372692193078999176;
+        const b = 573147844013817084101;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s = r.bezout_coeff_1;
+        const t = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+    {
+        const a = 453973694165307953197296969697410619233826;
+        const b = 280571172992510140037611932413038677189525;
+        const r = egcd(a, b);
+        const g = r.gcd;
+        const s = r.bezout_coeff_1;
+        const t = r.bezout_coeff_2;
+        try std.testing.expect(s * a + t * b == g);
+    }
+}
diff --git a/lib/std/math/gcd.zig b/lib/std/math/gcd.zig
index 16ca7846f19a..01c6b8b62914 100644
--- a/lib/std/math/gcd.zig
+++ b/lib/std/math/gcd.zig
@@ -1,7 +1,7 @@
-//! Greatest common divisor (https://mathworld.wolfram.com/GreatestCommonDivisor.html)
-const std = @import("std");
+//! Greatest Common Divisor (https://mathworld.wolfram.com/GreatestCommonDivisor.html)
+const std = @import("../std.zig");
 
-/// Returns the greatest common divisor (GCD) of two unsigned integers (`a` and `b`) which are not both zero.
+/// Returns the Greatest Common Divisor (GCD) of two unsigned integers (`a` and `b`) which are not both zero.
 /// For example, the GCD of `8` and `12` is `4`, that is, `gcd(8, 12) == 4`.
 pub fn gcd(a: anytype, b: anytype) @TypeOf(a, b) {
     const N = switch (@TypeOf(a, b)) {
@@ -9,12 +9,14 @@ pub fn gcd(a: anytype, b: anytype) @TypeOf(a, b) {
         comptime_int => std.math.IntFittingRange(@min(a, b), @max(a, b)),
         else => |T| T,
     };
+
     if (@typeInfo(N) != .int or @typeInfo(N).int.signedness != .unsigned) {
-        @compileError("`a` and `b` must be usigned integers");
+        @compileError("`a` and `b` must be unsigned integers");
     }
 
     // using an optimised form of Stein's algorithm:
     // https://en.wikipedia.org/wiki/Binary_GCD_algorithm
+
     std.debug.assert(a != 0 or b != 0);
 
     if (a == 0) return b;
@@ -26,25 +28,27 @@ pub fn gcd(a: anytype, b: anytype) @TypeOf(a, b) {
     const xz = @ctz(x);
     const yz = @ctz(y);
     const shift = @min(xz, yz);
-    x >>= @intCast(xz);
-    y >>= @intCast(yz);
-
-    var diff = y -% x;
-    while (diff != 0) : (diff = y -% x) {
-        // ctz is invariant under negation, we
-        // put it here to ease data dependencies,
-        // makes the CPU happy.
-        const zeros = @ctz(diff);
-        if (x > y) diff = -%diff;
-        y = @min(x, y);
-        x = diff >> @intCast(zeros);
+    x = @shrExact(x, @intCast(xz));
+    y = @shrExact(y, @intCast(yz));
+
+    var y_minus_x = y -% x;
+    while (y_minus_x != 0) : (y_minus_x = y -% x) {
+        const copy_x = x;
+        const zeros = @ctz(y_minus_x); // ctz is invariant under negation, so this also equals ctz(x -% y)
+        const carry = x < y; // true when `x -%= y` below wraps around
+        x -%= y;
+        if (carry) {
+            x = y_minus_x;
+            y = copy_x;
+        }
+        x = @shrExact(x, @intCast(zeros));
     }
-    return y << @intCast(shift);
+
+    return @shlExact(y, @intCast(shift));
 }
 
 test gcd {
     const expectEqual = std.testing.expectEqual;
-
     try expectEqual(gcd(0, 5), 5);
     try expectEqual(gcd(5, 0), 5);
     try expectEqual(gcd(8, 12), 4);