Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion turboquant/benchmark_google_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def compress(ks, li):
rq, pk = compressors[li]
flat = ks.reshape(-1, D).float()
kq = triton_rotor_full_fused(flat, pk, None,
getattr(rq, 'centroids_vector'), None, getattr(rq, 'centroids_trivector'))
getattr(rq, 'centroids_vector'), None, None)
return kq.to(ks.dtype).reshape(ks.shape)

_orig = DynamicCache.update
Expand Down
2 changes: 1 addition & 1 deletion turboquant/benchmark_perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def compress(ks, li):
flat, pk, None,
getattr(rq, 'centroids_vector'),
None,
getattr(rq, 'centroids_trivector'),
None,
)
return kq.to(ks.dtype).reshape(ks.shape)

Expand Down
10 changes: 5 additions & 5 deletions turboquant/benchmark_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def verify_correctness():
c_scalar = None
c_vector = getattr(rq, 'centroids_vector')
c_bivector = None
c_trivector = getattr(rq, 'centroids_trivector')
c_trivector = None

x_hat_triton = triton_rotor_full_fused(
x, packed_rotors, c_scalar, c_vector, c_bivector, c_trivector)
Expand Down Expand Up @@ -197,7 +197,7 @@ def benchmark_full_fused():
c_s = None
c_v = getattr(rq, 'centroids_vector')
c_b = None
c_t = getattr(rq, 'centroids_trivector')
c_t = None

print(f" GPU: {torch.cuda.get_device_name()}")
print(f" d={d}, bits={bits}\n")
Expand Down Expand Up @@ -328,7 +328,7 @@ def benchmark_varying_dimensions():
c_s = None
c_v = getattr(rq, 'centroids_vector')
c_b = None
c_t = getattr(rq, 'centroids_trivector')
c_t = None
n_groups = (d + 2) // 3

x = torch.randn(n, d, device=device)
Expand Down Expand Up @@ -381,7 +381,7 @@ def benchmark_vs_turboquant():
c_s = None
c_v = getattr(rq, 'centroids_vector')
c_b = None
c_t = getattr(rq, 'centroids_trivector')
c_t = None

for n in [1024, 4096, 16384]:
x = torch.randn(n, d, device=device)
Expand Down Expand Up @@ -437,7 +437,7 @@ def benchmark_bitwidth_sweep():
c_s = None
c_v = getattr(rq, 'centroids_vector')
c_b = None
c_t = getattr(rq, 'centroids_trivector')
c_t = None

# Quality
x_hat_pt, _ = rq(x)
Expand Down
3 changes: 1 addition & 2 deletions turboquant/poc_high_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def __init__(self, head_dim: int, bits: int, seed: int, device: str):
self.rq = RotorQuantMSE(head_dim, bits, seed=seed, device=device)
self.packed_rotors = pack_rotors_for_triton(self.rq.rotors).to(device)
self.c_v = getattr(self.rq, 'centroids_vector').to(device)
self.c_t = getattr(self.rq, 'centroids_trivector').to(device)
self.head_dim = head_dim
self.device = device

Expand All @@ -59,7 +58,7 @@ def compress_dequantize(self, keys: torch.Tensor) -> torch.Tensor:
# Triton fused: embed→rotor→quantize→unrotor→extract
flat_recon = triton_rotor_full_fused(
flat, self.packed_rotors,
None, self.c_v, None, self.c_t,
None, self.c_v, None, None,
)

return flat_recon.to(orig_dtype).reshape(B, H, S, D)
Expand Down
2 changes: 1 addition & 1 deletion turboquant/triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def triton_rotor_full_fused(
c_s = c_scalar.float().contiguous() if c_scalar is not None else c_vector.float().contiguous()
c_v = c_vector.float().contiguous()
c_b = c_bivector.float().contiguous() if c_bivector is not None else c_vector.float().contiguous()
c_t = c_trivector.float().contiguous()
c_t = c_trivector.float().contiguous() if c_trivector is not None else c_vector.float().contiguous()

output = torch.empty_like(input_f32)

Expand Down