"""Tests for building, serializing, and verifying proofs that carry their indices."""
import pytest
import torch

from toploc.poly import (
    ProofWithIndices,
    build_proofs_with_indices,
    save_proofs_with_indices,
    load_proofs_with_indices,
    proofs_with_indices_to_bytes,
    proofs_with_indices_from_bytes,
    verify_proofs_with_indices,
)
from toploc.C.csrc.poly import ProofPoly, VerificationResult


@pytest.fixture
def sample_activations():
    """Return one (3, 16) prefill tensor followed by seven (16,) decode tensors."""
    torch.manual_seed(42)
    DIM = 16
    activations = [torch.randn(3, DIM, dtype=torch.bfloat16)]
    # Same draw order as a plain append loop, so the seeded values are stable.
    activations.extend(
        torch.randn(DIM, dtype=torch.bfloat16) for _ in range(3 * 2 + 1)
    )
    return activations


def _assert_same_proofs(expected, actual):
    """Assert two ProofWithIndices lists match element by element."""
    assert len(actual) == len(expected)
    for want, got in zip(expected, actual):
        assert want.proof.coeffs == got.proof.coeffs
        assert want.proof.modulus == got.proof.modulus
        assert want.indices == got.indices


def test_build_proofs_with_indices(sample_activations):
    """Building yields one ProofWithIndices per batch, each with topk int indices."""
    proofs_with_indices = build_proofs_with_indices(
        sample_activations, decode_batching_size=2, topk=5
    )
    assert isinstance(proofs_with_indices, list)
    assert all(isinstance(pwi, ProofWithIndices) for pwi in proofs_with_indices)
    # 1 prefill proof + 4 decode batches (7 decode tensors, batch size 2).
    assert len(proofs_with_indices) == 5

    for pwi in proofs_with_indices:
        assert len(pwi.indices) == 5  # topk=5
        assert all(isinstance(idx, int) for idx in pwi.indices)


def test_proof_with_indices_to_dict():
    """to_dict exposes the base64 proof and the raw index list."""
    proof = ProofPoly([1, 2, 3], 65497)
    pwi = ProofWithIndices(proof=proof, indices=[10, 20, 30])

    d = pwi.to_dict()
    assert "proof_base64" in d
    assert "indices" in d
    assert d["indices"] == [10, 20, 30]


def test_proof_with_indices_from_dict():
    """from_dict(to_dict(x)) reproduces coefficients, modulus, and indices."""
    proof = ProofPoly([1, 2, 3], 65497)
    pwi = ProofWithIndices(proof=proof, indices=[10, 20, 30])

    pwi2 = ProofWithIndices.from_dict(pwi.to_dict())

    assert pwi2.proof.coeffs == pwi.proof.coeffs
    assert pwi2.proof.modulus == pwi.proof.modulus
    assert pwi2.indices == pwi.indices


def test_save_and_load_proofs_with_indices(sample_activations, tmp_path):
    """Proofs survive a round trip through a JSON file on disk."""
    proofs_with_indices = build_proofs_with_indices(
        sample_activations, decode_batching_size=2, topk=5
    )

    # tmp_path is created and cleaned up by pytest, even when an assert fails.
    file_path = str(tmp_path / "proofs.json")
    save_proofs_with_indices(proofs_with_indices, file_path)
    loaded = load_proofs_with_indices(file_path)

    _assert_same_proofs(proofs_with_indices, loaded)


def test_proofs_with_indices_to_and_from_bytes(sample_activations):
    """Proofs survive a round trip through the bytes serialization."""
    proofs_with_indices = build_proofs_with_indices(
        sample_activations, decode_batching_size=2, topk=5
    )

    data = proofs_with_indices_to_bytes(proofs_with_indices)
    assert isinstance(data, bytes)

    loaded = proofs_with_indices_from_bytes(data)

    _assert_same_proofs(proofs_with_indices, loaded)


def test_verify_proofs_with_indices(sample_activations):
    """Verifying against the exact activations used for building reports no error."""
    proofs_with_indices = build_proofs_with_indices(
        sample_activations, decode_batching_size=3, topk=4
    )

    results = verify_proofs_with_indices(
        sample_activations,
        proofs_with_indices,
        decode_batching_size=3,
        topk=4,
    )

    assert isinstance(results, list)
    assert all(isinstance(r, VerificationResult) for r in results)
    assert len(results) == len(proofs_with_indices)
    assert all(r.exp_mismatches == 0 for r in results)
    assert all(r.mant_err_mean == 0 for r in results)
    assert all(r.mant_err_median == 0 for r in results)


def test_verify_proofs_with_indices_slightly_modified(sample_activations):
    """A 1% scaling keeps exponents intact but introduces small mantissa errors."""
    proofs_with_indices = build_proofs_with_indices(
        sample_activations, decode_batching_size=3, topk=4
    )

    modified = [a * 1.01 for a in sample_activations]

    results = verify_proofs_with_indices(
        modified,
        proofs_with_indices,
        decode_batching_size=3,
        topk=4,
    )

    assert isinstance(results, list)
    # Guard against the all(...) checks below passing vacuously on [].
    assert len(results) == len(proofs_with_indices)
    assert all(isinstance(r, VerificationResult) for r in results)
    assert all(r.exp_mismatches == 0 for r in results)
    assert all(0 < r.mant_err_mean <= 2 for r in results)


def test_build_proofs_with_indices_skip_prefill(sample_activations):
    """With skip_prefill every tensor is a decode tensor and no prefill proof is built."""
    proofs_with_indices = build_proofs_with_indices(
        sample_activations[1:], decode_batching_size=2, topk=5, skip_prefill=True
    )
    assert isinstance(proofs_with_indices, list)
    assert all(isinstance(pwi, ProofWithIndices) for pwi in proofs_with_indices)
    # 7 decode tensors in batches of 2 -> 4 proofs.
    assert len(proofs_with_indices) == 4


def test_build_proofs_with_indices_error_handling():
    """Unusable activations yield null proofs with empty indices instead of raising."""
    invalid_activations = [
        torch.randn(0, 16, dtype=torch.bfloat16),  # empty: topk(5) cannot succeed
        torch.randn(16, dtype=torch.bfloat16),
    ]
    proofs_with_indices = build_proofs_with_indices(
        invalid_activations, decode_batching_size=2, topk=5
    )
    assert isinstance(proofs_with_indices, list)
    # One prefill slot + one decode batch; also keeps the loop below from
    # passing vacuously on an empty list.
    assert len(proofs_with_indices) == 2
    for pwi in proofs_with_indices:
        assert pwi.proof.modulus == 0
        assert pwi.indices == []
# NOTE: every public name defined below (ProofWithIndices,
# build_proofs_with_indices, verify_proofs_with_indices,
# save_proofs_with_indices, load_proofs_with_indices,
# proofs_with_indices_to_bytes, proofs_with_indices_from_bytes) is re-exported
# from toploc/__init__.py; keep that import list in sync when renaming.
from typing import Dict, Any, List
from dataclasses import dataclass
import json


# Sentinel error value reported when a proof cannot be meaningfully compared.
_UNVERIFIABLE_ERR = 2**64


@dataclass
class ProofWithIndices:
    """A proof polynomial together with the flat indices it was built from."""

    proof: ProofPoly
    indices: List[int]

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict with keys ``proof_base64`` and ``indices``."""
        return {
            "proof_base64": self.proof.to_base64(),
            # Copy so that mutating the returned dict cannot alter this object.
            "indices": list(self.indices),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ProofWithIndices":
        """Rebuild an instance from a dict produced by :meth:`to_dict`.

        Raises:
            KeyError: If ``proof_base64`` or ``indices`` is missing.
        """
        return cls(
            proof=ProofPoly.from_base64(data["proof_base64"]),
            indices=list(data["indices"]),
        )


def _expected_proof_count(
    num_activations: int, decode_batching_size: int, skip_prefill: bool
) -> int:
    """Return how many proofs a successful build produces for these inputs."""
    prefill = 0 if skip_prefill else 1
    decode = max(num_activations - prefill, 0)
    # Ceiling division: a trailing partial batch still yields one proof.
    return prefill + -(-decode // decode_batching_size)


def _proof_with_indices_from_flat(
    flat_view: torch.Tensor, topk: int
) -> ProofWithIndices:
    """Build one ProofWithIndices from the topk largest-magnitude entries."""
    topk_indices = flat_view.abs().topk(topk).indices
    topk_values = flat_view[topk_indices]
    # Move to CPU before handing the points to ProofPoly, as the decode path
    # always did; this is a no-op for tensors that already live on the CPU.
    topk_indices = topk_indices.to("cpu")
    topk_values = topk_values.to("cpu")
    proof = ProofPoly.from_points_tensor(topk_indices, topk_values)
    return ProofWithIndices(proof=proof, indices=topk_indices.tolist())


def build_proofs_with_indices(
    activations: list[torch.Tensor],
    decode_batching_size: int,
    topk: int,
    skip_prefill: bool = False,
) -> list[ProofWithIndices]:
    """Build proofs and record the topk indices each proof was built from.

    Unless ``skip_prefill`` is set, the first tensor gets its own proof. The
    remaining tensors are flattened and concatenated in groups of
    ``decode_batching_size``, one proof per group. Storing the indices lets
    :func:`verify_proofs_with_indices` sample candidate activations at the
    same positions instead of recomputing a top-k selection.

    Args:
        activations: List of activation tensors.
        decode_batching_size: Number of decode tensors per proof; must be >= 1.
        topk: Number of largest-magnitude entries encoded in each proof.
        skip_prefill: If True, treat every tensor as a decode tensor.

    Returns:
        List of ProofWithIndices objects. If building fails for any reason,
        the error is logged and the list holds null proofs with empty indices,
        one per proof a successful build would have produced.

    Raises:
        ValueError: If ``decode_batching_size`` is smaller than 1.
    """
    if decode_batching_size < 1:
        raise ValueError("decode_batching_size must be a positive integer")

    first_decode = 0 if skip_prefill else 1
    try:
        proofs_with_indices = []

        # Prefill
        if not skip_prefill:
            proofs_with_indices.append(
                _proof_with_indices_from_flat(activations[0].view(-1), topk)
            )

        # Batched decode
        for start in range(first_decode, len(activations), decode_batching_size):
            batch = activations[start : start + decode_batching_size]
            flat_view = torch.cat([act.view(-1) for act in batch])
            proofs_with_indices.append(
                _proof_with_indices_from_flat(flat_view, topk)
            )
    except Exception as e:
        # Deliberate best effort: callers get placeholders instead of a crash.
        logger.error("Error building proofs with indices: %s", e)
        num_proofs = _expected_proof_count(
            len(activations), decode_batching_size, skip_prefill
        )
        proofs_with_indices = [
            ProofWithIndices(proof=ProofPoly.null(topk), indices=[])
            for _ in range(num_proofs)
        ]

    return proofs_with_indices


def save_proofs_with_indices(
    proofs_with_indices: list[ProofWithIndices],
    file_path: str,
) -> None:
    """Write proofs with indices to ``file_path`` as a JSON array.

    Args:
        proofs_with_indices: List of ProofWithIndices objects.
        file_path: Path of the JSON file to create or overwrite.
    """
    data = [pwi.to_dict() for pwi in proofs_with_indices]
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(data, f)


def load_proofs_with_indices(file_path: str) -> list[ProofWithIndices]:
    """Read proofs with indices from a JSON file written by ``save_proofs_with_indices``.

    Args:
        file_path: Path to the input JSON file.

    Returns:
        List of ProofWithIndices objects, in file order.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    return [ProofWithIndices.from_dict(item) for item in data]


def proofs_with_indices_to_bytes(
    proofs_with_indices: list[ProofWithIndices],
) -> bytes:
    """Serialize proofs with indices to UTF-8 encoded JSON bytes.

    Args:
        proofs_with_indices: List of ProofWithIndices objects.

    Returns:
        JSON array of the ``to_dict`` form of every element, UTF-8 encoded.
    """
    data = [pwi.to_dict() for pwi in proofs_with_indices]
    return json.dumps(data).encode("utf-8")


def proofs_with_indices_from_bytes(data: bytes) -> list[ProofWithIndices]:
    """Deserialize bytes produced by ``proofs_with_indices_to_bytes``.

    Args:
        data: UTF-8 encoded JSON array.

    Returns:
        List of ProofWithIndices objects, in serialized order.
    """
    items = json.loads(data.decode("utf-8"))
    return [ProofWithIndices.from_dict(item) for item in items]


def verify_proofs_with_indices(
    activations: list[torch.Tensor],
    proofs_with_indices: list[ProofWithIndices],
    decode_batching_size: int,
    topk: int,
    skip_prefill: bool = False,
) -> list[VerificationResult]:
    """Verify proofs against activations sampled at each proof's stored indices.

    The activations are grouped with ``batch_activations``. For each group the
    values at the stored indices are compared with the values the proof
    polynomial evaluates to at those indices: exponents must match, and the
    mantissa error is aggregated over the entries whose exponents do match.

    Args:
        activations: List of activation tensors, batched the same way as when
            the proofs were built.
        proofs_with_indices: Proofs with the indices recorded at build time.
        decode_batching_size: Number of decode tensors per batch.
        topk: Unused here (the stored indices decide how many points are
            checked); kept for signature parity with the build function.
        skip_prefill: Whether the activations start without a prefill tensor.

    Returns:
        One VerificationResult(exp_mismatches, mant_err_mean, mant_err_median)
        per (proof, batch) pair. Both mantissa fields are ``2**64`` when
        nothing could be compared: the proof has no indices, its indices fall
        outside the batch, or every exponent mismatched.
    """
    batches = list(
        batch_activations(
            activations,
            decode_batching_size=decode_batching_size,
            skip_prefill=skip_prefill,
        )
    )
    if len(batches) != len(proofs_with_indices):
        # zip() below stops at the shorter input; make the truncation visible.
        logger.warning(
            "Got %d proofs for %d activation batches; only the first %d are verified",
            len(proofs_with_indices),
            len(batches),
            min(len(batches), len(proofs_with_indices)),
        )

    results = []
    for proof_with_idx, chunk in zip(proofs_with_indices, batches):
        proof = proof_with_idx.proof
        stored_indices = proof_with_idx.indices

        if not stored_indices:
            # No indices stored (e.g. a null proof), so nothing can be compared.
            results.append(
                VerificationResult(0, _UNVERIFIABLE_ERR, _UNVERIFIABLE_ERR)
            )
            continue

        chunk = chunk.view(-1).cpu()

        if min(stored_indices) < 0 or max(stored_indices) >= chunk.numel():
            # Stored indices may come from an untrusted file. Negative values
            # would silently wrap around and too-large ones would raise, so
            # report the whole proof as failed instead.
            logger.warning(
                "Stored indices fall outside an activation batch of %d elements",
                chunk.numel(),
            )
            results.append(
                VerificationResult(
                    len(stored_indices), _UNVERIFIABLE_ERR, _UNVERIFIABLE_ERR
                )
            )
            continue

        # Sample the candidate activations at the recorded positions.
        stored_indices_tensor = torch.tensor(stored_indices, dtype=torch.long)
        topk_values = chunk[stored_indices_tensor]

        # Evaluate the proof at the same positions; the results are raw 16-bit
        # patterns that are reinterpreted (not converted) as bfloat16.
        y_values = evaluate_polynomials(proof.coeffs, stored_indices)
        proof_topk_values = torch.tensor(y_values, dtype=torch.uint16).view(
            dtype=torch.bfloat16
        )

        proof_exps, proof_mants = get_fp_parts(proof_topk_values)
        act_exps, act_mants = get_fp_parts(topk_values)

        exp_mismatches = [p != a for p, a in zip(proof_exps, act_exps)]
        # Mantissas are only comparable when the exponents agree.
        mant_errs = [
            abs(p - a)
            for p, a, mismatched in zip(proof_mants, act_mants, exp_mismatches)
            if not mismatched
        ]
        if mant_errs:
            results.append(
                VerificationResult(
                    sum(exp_mismatches), mean(mant_errs), median(mant_errs)
                )
            )
        else:
            results.append(
                VerificationResult(
                    sum(exp_mismatches), _UNVERIFIABLE_ERR, _UNVERIFIABLE_ERR
                )
            )

    return results