"""Age-verification model benchmark: download model/data/circuits, then
generate and verify EZKL zero-knowledge proofs and score accuracy/size/latency.

Originally a single Jupyter-notebook cell; runs as a plain script via main().
"""

import os
import sys
import time
import zipfile
import cv2
import numpy as np
import onnxruntime as ort
import torch
import requests
import tempfile
import subprocess
import json
from tqdm import tqdm
from typing import Tuple, List, Optional
import random

ort.set_default_logger_severity(3)

# Model
MODEL_URL = 'https://zkevm-4.s3.us-east-2.amazonaws.com/model'
MODEL_NAME = 'age.onnx'
DOWNLOAD_MODEL_PATH = './competition/model/'

# Test image datasets
IMAGE_DATASETS_URL = 'https://storage.omron.ai/age.zip'
DOWNLOAD_IMAGE_DATASETS = './competition/test_data/'

# Circuit
CIRCUIT_BASE_URL = 'https://zkevm-4.s3.us-east-2.amazonaws.com/'
CIRCUIT_NAMES = ['circuit1', 'circuit2', 'circuit3']
# CIRCUIT_NAMES = ['circuit1','circuit2']
DOWNLOAD_CIRCUIT_PATH = './competition/circuit/'

LOCAL_EZKL_PATH = '../../target/release/ezkl'
TEMP_FOLDER = './competition/tmp'

TEST_COUNT = 10

# Artifacts that make up one compiled circuit.  Hoisted to a single constant
# so the existence check and the download loop cannot drift out of sync.
CIRCUIT_FILE_NAMES = ['kzg.srs', 'model.compiled', 'pk.key', 'settings.json', 'vk.key']


def download_model(url: str, save_directory: str, model_name: str) -> None:
    """Download the ONNX model to save_directory, skipping if already present."""
    save_path = os.path.join(save_directory, model_name)

    if os.path.exists(save_path):
        print(f"{save_path} already exists. Skipping download.")
        return

    os.makedirs(save_directory, exist_ok=True)

    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))

        with open(save_path, 'wb') as file, tqdm(
            desc=model_name,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in response.iter_content(chunk_size=1024):
                file.write(data)
                bar.update(len(data))

        print(f"Downloaded {model_name} to {save_path}")

    except requests.exceptions.RequestException as e:
        print(f"Error downloading {model_name}: {e}")


def download_circuit_files(urls: List[str], base_directory: str) -> list:
    """Download every circuit artifact for each URL; return local circuit dirs.

    Each URL is expected to end with '<circuit_name>/' — the second-to-last
    path segment names the local folder.
    """
    download_circuits_path = []
    for url in urls:
        circuit_name = url.split('/')[-2]
        circuit_folder = os.path.join(base_directory, circuit_name)

        os.makedirs(circuit_folder, exist_ok=True)

        all_files_exist = True
        for file_name in CIRCUIT_FILE_NAMES:
            save_path = os.path.join(circuit_folder, file_name)
            if os.path.exists(save_path):
                print(f"{save_path} already exists. Skipping download.")
            else:
                all_files_exist = False

        if all_files_exist:
            print(f"All files for {circuit_name} already exist. Skipping download.")
            print(f"download_circuits_path:{download_circuits_path}")
            download_circuits_path.append(circuit_folder)
            continue

        for file_name in CIRCUIT_FILE_NAMES:
            file_url = f"{url}{file_name}"
            save_path = os.path.join(circuit_folder, file_name)

            response = requests.get(file_url, stream=True)
            # Fail loudly on HTTP errors instead of writing an error page to disk
            # (download_model already does this; keep the two paths consistent).
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(save_path, 'wb') as file, tqdm(
                desc=file_name,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=1024):
                    file.write(data)
                    bar.update(len(data))

            print(f"Downloaded {file_name} to {save_path}")
        download_circuits_path.append(circuit_folder)
        print(f"download_circuits_path:{download_circuits_path}")
        print(f"All files are downloaded successfully!")

    return download_circuits_path


def download_and_process_images(download_image_dataset: str, image_datasets_url: str) -> str:
    """Download the age dataset zip, extract it, resize all images to 64x64.

    Returns the directory containing the processed 64x64 images.  Skips all
    work when the processed directory already exists and is non-empty.
    """
    if os.path.exists(download_image_dataset):
        print(f"Test datasets dir {download_image_dataset} already exists!")
    else:
        print(f"Test datasets dir {download_image_dataset} does not exist, creating it.")
        os.makedirs(download_image_dataset)

    zip_path = os.path.join(download_image_dataset, "age.zip")
    extracted_path = os.path.join(download_image_dataset, "extracted")
    processed_path = os.path.join(download_image_dataset, "processed_64x64")

    os.makedirs(extracted_path, exist_ok=True)
    os.makedirs(processed_path, exist_ok=True)

    if os.path.exists(processed_path) and os.listdir(processed_path):
        print(f"{processed_path} already exists and is not empty. Skipping processing.")
        return processed_path

    print("Downloading dataset...")
    response = requests.get(image_datasets_url, stream=True)
    total_size = int(response.headers.get("content-length", 0))

    with open(zip_path, "wb") as f, tqdm(
        desc="Downloading", total=total_size, unit="iB", unit_scale=True
    ) as pbar:
        for data in response.iter_content(chunk_size=1024):
            size = f.write(data)
            pbar.update(size)

    print("Extracting zip...")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    print("Processing images to 64x64...")
    for root, _, files in tqdm(os.walk(extracted_path)):
        for img_name in files:
            if img_name.lower().endswith((".png", ".jpg", ".jpeg")):
                img_path = os.path.join(root, img_name)
                try:
                    img = cv2.imread(img_path)
                    if img is not None:
                        img = cv2.resize(img, (64, 64))
                        cv2.imwrite(os.path.join(processed_path, img_name), img)
                except Exception as e:
                    print(f"Failed to process {img_name}: {e}")

    print(f"Images processed and saved to {processed_path}")
    return processed_path


class ImageProcessor:
    """Stateless image-to-tensor helpers (ImageNet-style normalization)."""

    @staticmethod
    def normalize(
        img: torch.Tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    ) -> torch.Tensor:
        """Channel-wise (img - mean) / std; mean/std broadcast over H and W."""
        mean = torch.tensor(mean).view(-1, 1, 1)
        std = torch.tensor(std).view(-1, 1, 1)
        return (img - mean) / std

    @staticmethod
    def to_tensor(img: np.ndarray) -> torch.Tensor:
        """Convert an HWC uint8 image array to a CHW float tensor in [0, 1]."""
        img = img.transpose((2, 0, 1))  # HWC to CHW
        img = torch.from_numpy(img).float()
        return img / 255.0


def preprocess_image(img_path: str) -> Optional[torch.Tensor]:
    """Load an image and return a normalized (1, C, H, W) tensor, or None on error."""
    try:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for missing/unreadable files.
            print(f" Error: cannot read image: {img_path}")
            return None

        image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Convert to tensor and normalize
        tensor = ImageProcessor.to_tensor(image)
        tensor = ImageProcessor.normalize(tensor)

        # Add batch dimension
        tensor = tensor.unsqueeze(0)

        return tensor
    except Exception as e:
        print(f" Error processing image data: {e}")
        return None


def inference_onnx_model(model_path: str, input_tensor: torch.Tensor):
    """Run the ONNX model on input_tensor; return the list of output arrays,
    or None on failure."""
    try:
        session = ort.InferenceSession(model_path)
        input_name = session.get_inputs()[0].name

        # Convert input tensor to numpy array (detach first if it tracks grads)
        input_data = input_tensor.detach().cpu().numpy() if input_tensor.requires_grad else input_tensor.cpu().numpy()
        # Prepare input feed
        options = ort.RunOptions()
        options.log_severity_level = 3

        output_names = [output.name for output in session.get_outputs()]
        outputs = session.run(output_names, {input_name: input_data}, options)

        return outputs

    except Exception as e:
        print(f"Error during ONNX inference: {e}")
        return None


def get_temp_folder() -> str:
    """Return the scratch directory for ezkl inputs/witnesses/proofs, creating it if needed."""
    if not os.path.exists(TEMP_FOLDER):
        os.makedirs(TEMP_FOLDER, exist_ok=True)
    return TEMP_FOLDER


def _remove_quietly(*paths: str) -> None:
    """Best-effort deletion of temp files; missing files are ignored."""
    for path in paths:
        try:
            if path and os.path.exists(path):
                os.unlink(path)
        except OSError:
            pass


def generate_proof(
    circuit_dir: str, test_inputs: torch.Tensor
) -> Optional[Tuple[str, dict, float]]:
    """Generate an ezkl witness and proof for test_inputs against circuit_dir.

    Returns (proof_path, proof_data, proof_seconds) on success, or None on any
    failure.  The caller is responsible for deleting proof_path (verify_proof
    does this); all other temp files are cleaned up here, including on the
    failure paths (the original leaked them when witness/prove failed).
    """
    temp_input_path = witness_path = temp_proof_path = None
    try:
        input_data = {
            "input_data": [[float(x) for x in test_inputs.flatten().tolist()]]
        }

        with tempfile.NamedTemporaryFile(
            mode="w+", suffix=".json", dir=get_temp_folder(), delete=False
        ) as temp_input:
            json.dump(input_data, temp_input, indent=2)
            temp_input_path = temp_input.name

        with tempfile.NamedTemporaryFile(
            mode="w+", suffix=".json", dir=get_temp_folder(), delete=False
        ) as temp_witness:
            witness_path = temp_witness.name

        with tempfile.NamedTemporaryFile(
            mode="w+", suffix=".json", dir=get_temp_folder(), delete=False
        ) as temp_proof:
            temp_proof_path = temp_proof.name

        model_path = os.path.join(circuit_dir, "model.compiled")
        if not os.path.exists(model_path):
            print(f"model.compiled not found at {model_path}")
            _remove_quietly(temp_input_path, witness_path, temp_proof_path)
            return None

        witness_result = subprocess.run(
            [
                LOCAL_EZKL_PATH,
                "gen-witness",
                "--data",
                temp_input_path,
                "--compiled-circuit",
                model_path,
                "--output",
                witness_path,
            ],
            capture_output=True,
            text=True,
            timeout=300,
        )

        if witness_result.returncode != 0:
            print(
                f"Witness generation failed with code {witness_result.returncode}"
            )
            print(f"STDOUT: {witness_result.stdout}")
            print(f"STDERR: {witness_result.stderr}")
            _remove_quietly(temp_input_path, witness_path, temp_proof_path)
            return None

        proof_start = time.perf_counter()
        prove_result = subprocess.run(
            [
                LOCAL_EZKL_PATH,
                "prove",
                "--compiled-circuit",
                model_path,
                "--witness",
                witness_path,
                "--pk-path",
                os.path.join(circuit_dir, "pk.key"),
                "--proof-path",
                temp_proof_path,
            ],
            capture_output=True,
            text=True,
            timeout=300,
        )
        proof_time = time.perf_counter() - proof_start

        _remove_quietly(temp_input_path, witness_path)

        if prove_result.returncode != 0:
            print(
                f"Proof generation failed with code {prove_result.returncode}"
            )
            print(f"STDOUT: {prove_result.stdout}")
            print(f"STDERR: {prove_result.stderr}")
            _remove_quietly(temp_proof_path)
            return None

        with open(temp_proof_path) as f:
            proof_data = json.load(f)
        return temp_proof_path, proof_data, proof_time
    except Exception as e:
        print(f"Error generating proof: {e}")
        _remove_quietly(temp_input_path, witness_path, temp_proof_path)
        return None


def verify_proof(circuit_dir: str, proof_path: str) -> bool:
    """Verify the proof at proof_path with ezkl; always deletes the proof file."""
    try:
        verify_result = subprocess.run(
            [
                LOCAL_EZKL_PATH,
                "verify",
                "--proof-path",
                proof_path,
                "--settings-path",
                os.path.join(circuit_dir, "settings.json"),
                "--vk-path",
                os.path.join(circuit_dir, "vk.key"),
            ],
            capture_output=True,
            text=True,
            timeout=300,
        )
        return verify_result.returncode == 0
    except Exception as e:
        print(f"Error verifying proof: {e}")
        return False
    finally:
        if os.path.exists(proof_path):
            os.unlink(proof_path)


def compare_outputs(expected: list, actual: list) -> float:
    """Return exp(-MAE) between flattened expected/actual outputs (1.0 = identical).

    Returns 0.0 when the tensors cannot be compared (e.g. shape mismatch);
    the failure is logged instead of silently swallowed.
    """
    try:
        expected_tensor = torch.tensor(expected)
        actual_tensor = torch.tensor(actual)

        expected_flat = expected_tensor.flatten()
        actual_flat = actual_tensor.flatten()

        mae = torch.nn.functional.l1_loss(actual_flat, expected_flat)
        raw_accuracy = torch.exp(-mae).item()
        return raw_accuracy
    except Exception as e:
        print(f"Error comparing outputs: {e}")
        return 0.0


def benchmark(
    onnx_model_path: str, processed_path: str, circuit_dir: str, test_count: int
) -> Tuple[float, float, bool, float]:
    """Benchmark one circuit on test_count random images.

    Returns (avg_proof_size, avg_response_time, all_verified, avg_raw_accuracy)
    — always a 4-tuple in that order, matching what main() unpacks.  (The
    original failure path returned a 5-tuple in a different order, which made
    main() raise ValueError whenever any verification failed.)
    """
    image_files = [f for f in os.listdir(processed_path) if f.lower().endswith((".png", ".jpg", ".jpeg"))]

    selected_images = random.sample(image_files, min(test_count, len(image_files)))
    print(f"Selected images: {selected_images}")

    raw_accuracy_scores, proof_sizes, response_times, verification_results = (
        [],
        [],
        [],
        [],
    )
    for img_name in selected_images:
        image_path = os.path.join(processed_path, img_name)
        input_tensor = preprocess_image(image_path)

        output_tensor1 = inference_onnx_model(onnx_model_path, input_tensor)

        # Flatten all ONNX outputs into one baseline vector for comparison.
        flattened = []
        for out in output_tensor1:
            flattened.extend(out.flatten())
        baseline_output = np.array(flattened)

        proof_result = generate_proof(circuit_dir, input_tensor)
        if not proof_result:
            print("Proof generation failed")
            raw_accuracy_scores.append(0.0)
            verification_results.append(False)
            proof_sizes.append(float("inf"))
            response_times.append(float("inf"))
            continue

        proof_path, proof_data, response_time = proof_result
        response_times.append(response_time)

        proof = proof_data.get("proof", [])
        public_signals = [
            float(x)
            for sublist in proof_data.get("pretty_public_inputs", {}).get(
                "rescaled_outputs", []
            )
            for x in sublist
        ]
        proof_sizes.append(len(proof))

        verify_result = verify_proof(circuit_dir, proof_path)
        verification_results.append(verify_result)

        if verify_result:
            raw_accuracy = compare_outputs(baseline_output, public_signals)
            raw_accuracy_scores.append(raw_accuracy)
        else:
            print("Proof verification failed")
            raw_accuracy_scores.append(0.0)

    if not all(verification_results):
        print(
            "One or more verifications failed - setting all scores to 0"
        )
        # Same shape/order as the success return below.
        return float("inf"), float("inf"), False, 0.0

    avg_raw_accuracy = (
        sum(raw_accuracy_scores) / len(raw_accuracy_scores)
        if raw_accuracy_scores
        else 0
    )
    avg_proof_size = (
        sum(proof_sizes) / len(proof_sizes) if proof_sizes else float("inf")
    )
    avg_response_time = (
        sum(response_times) / len(response_times)
        if response_times
        else float("inf")
    )

    return (
        avg_proof_size,
        avg_response_time,
        True,
        avg_raw_accuracy,
    )


def main():
    """Download all artifacts and benchmark every circuit against the model."""
    print("** Age Verification Optimize Test **")

    print("** Downloading model... **")
    model_path = os.path.join(DOWNLOAD_MODEL_PATH, MODEL_NAME)
    download_model(MODEL_URL, DOWNLOAD_MODEL_PATH, MODEL_NAME)

    print("** Downloading test image datasets... **")
    processed_path = download_and_process_images(DOWNLOAD_IMAGE_DATASETS, IMAGE_DATASETS_URL)

    print("** Downloading circuit files... **")
    circuit_urls = [f"{CIRCUIT_BASE_URL}{name}/" for name in CIRCUIT_NAMES]
    download_circuits_path = download_circuit_files(circuit_urls, DOWNLOAD_CIRCUIT_PATH)

    print("** Start to benchmark... **")
    for circuit_path in download_circuits_path:
        print(f"** Running: {circuit_path}... **")
        avg_proof_size, avg_response_time, verification_results, avg_raw_accuracy = benchmark(model_path, processed_path, circuit_path, TEST_COUNT)
        print(
            f"** {circuit_path} Result: \n"
            f"  - avg_proof_size: {avg_proof_size}\n"
            f"  - avg_response_time: {avg_response_time}\n"
            f"  - verification_results: {verification_results}\n"
            f"  - avg_raw_accuracy: {avg_raw_accuracy}"
        )
    print("** Benchmark Completed... **")


if __name__ == "__main__":
    main()