diff --git a/turboquant/benchmark_google_parity.py b/turboquant/benchmark_google_parity.py index 8f108f6..cee574e 100644 --- a/turboquant/benchmark_google_parity.py +++ b/turboquant/benchmark_google_parity.py @@ -66,7 +66,7 @@ def compress(ks, li): rq, pk = compressors[li] flat = ks.reshape(-1, D).float() kq = triton_rotor_full_fused(flat, pk, None, - getattr(rq, 'centroids_vector'), None, getattr(rq, 'centroids_trivector')) + getattr(rq, 'centroids_vector'), None, None) return kq.to(ks.dtype).reshape(ks.shape) _orig = DynamicCache.update diff --git a/turboquant/benchmark_perplexity.py b/turboquant/benchmark_perplexity.py index 644ec2d..b24b91a 100644 --- a/turboquant/benchmark_perplexity.py +++ b/turboquant/benchmark_perplexity.py @@ -94,7 +94,7 @@ def compress(ks, li): flat, pk, None, getattr(rq, 'centroids_vector'), None, - getattr(rq, 'centroids_trivector'), + None, ) return kq.to(ks.dtype).reshape(ks.shape) diff --git a/turboquant/benchmark_triton.py b/turboquant/benchmark_triton.py index a9e32d0..1e84021 100644 --- a/turboquant/benchmark_triton.py +++ b/turboquant/benchmark_triton.py @@ -114,7 +114,7 @@ def verify_correctness(): c_scalar = None c_vector = getattr(rq, 'centroids_vector') c_bivector = None - c_trivector = getattr(rq, 'centroids_trivector') + c_trivector = None x_hat_triton = triton_rotor_full_fused( x, packed_rotors, c_scalar, c_vector, c_bivector, c_trivector) @@ -197,7 +197,7 @@ def benchmark_full_fused(): c_s = None c_v = getattr(rq, 'centroids_vector') c_b = None - c_t = getattr(rq, 'centroids_trivector') + c_t = None print(f" GPU: {torch.cuda.get_device_name()}") print(f" d={d}, bits={bits}\n") @@ -328,7 +328,7 @@ def benchmark_varying_dimensions(): c_s = None c_v = getattr(rq, 'centroids_vector') c_b = None - c_t = getattr(rq, 'centroids_trivector') + c_t = None n_groups = (d + 2) // 3 x = torch.randn(n, d, device=device) @@ -381,7 +381,7 @@ def benchmark_vs_turboquant(): c_s = None c_v = getattr(rq, 'centroids_vector') c_b = None - c_t = getattr(rq, 'centroids_trivector') + c_t = None for n in [1024, 4096, 16384]: x = torch.randn(n, d, device=device) @@ -437,7 +437,7 @@ def benchmark_bitwidth_sweep(): c_s = None c_v = getattr(rq, 'centroids_vector') c_b = None - c_t = getattr(rq, 'centroids_trivector') + c_t = None # Quality x_hat_pt, _ = rq(x) diff --git a/turboquant/poc_high_context.py b/turboquant/poc_high_context.py index 093e372..56de7a2 100644 --- a/turboquant/poc_high_context.py +++ b/turboquant/poc_high_context.py @@ -37,7 +37,6 @@ def __init__(self, head_dim: int, bits: int, seed: int, device: str): self.rq = RotorQuantMSE(head_dim, bits, seed=seed, device=device) self.packed_rotors = pack_rotors_for_triton(self.rq.rotors).to(device) self.c_v = getattr(self.rq, 'centroids_vector').to(device) - self.c_t = getattr(self.rq, 'centroids_trivector').to(device) self.head_dim = head_dim self.device = device @@ -59,7 +58,7 @@ def compress_dequantize(self, keys: torch.Tensor) -> torch.Tensor: # Triton fused: embed→rotor→quantize→unrotor→extract flat_recon = triton_rotor_full_fused( flat, self.packed_rotors, - None, self.c_v, None, self.c_t, + None, self.c_v, None, None, ) return flat_recon.to(orig_dtype).reshape(B, H, S, D) diff --git a/turboquant/triton_kernels.py b/turboquant/triton_kernels.py index 8a20924..7e654e8 100644 --- a/turboquant/triton_kernels.py +++ b/turboquant/triton_kernels.py @@ -289,7 +289,7 @@ def triton_rotor_full_fused( c_s = c_scalar.float().contiguous() if c_scalar is not None else c_vector.float().contiguous() c_v = c_vector.float().contiguous() c_b = c_bivector.float().contiguous() if c_bivector is not None else c_vector.float().contiguous() - c_t = c_trivector.float().contiguous() + c_t = c_trivector.float().contiguous() if c_trivector is not None else c_vector.float().contiguous() output = torch.empty_like(input_f32)